diff options
Diffstat (limited to 'pkg/tcpip/network/ipv6/ipv6.go')
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 68 |
1 files changed, 55 insertions, 13 deletions
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 181c50cc7..ac67d4ac5 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -86,6 +86,8 @@ type endpoint struct { addressableEndpointState stack.AddressableEndpointState ndp ndpState } + + mld mldState } // NICNameFromID is a function that returns a stable name for the specified NIC, @@ -243,7 +245,7 @@ func (e *endpoint) Enable() *tcpip.Error { // (NDP NS) messages may be sent to the All-Nodes multicast group if the // source address of the NDP NS is the unspecified address, as per RFC 4861 // section 7.2.4. - if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil { + if _, err := e.joinGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil { return err } @@ -333,7 +335,7 @@ func (e *endpoint) disableLocked() { e.stopDADForPermanentAddressesLocked() // The endpoint may have already left the multicast group. - if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + if _, err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) } } @@ -378,7 +380,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } -func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { length := uint16(pkt.Size()) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -386,8 +388,8 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s NextHeader: uint8(params.Protocol), HopLimit: params.TTL, TrafficClass: params.TOS, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + SrcAddr: srcAddr, + DstAddr: dstAddr, }) pkt.NetworkProtocolNumber = ProtocolNumber } @@ -442,7 +444,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r, pkt, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params) // iptables filtering. All packets that reach here are locally // generated. @@ -531,7 +533,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { - e.addIPHeader(r, pb, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params) networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size())) if err != nil { @@ -1164,7 +1166,7 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre } snmc := header.SolicitedNodeAddr(addr.Address) - if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil { + if _, err := e.joinGroupLocked(snmc); err != nil { return nil, err } @@ -1221,7 +1223,7 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn } snmc := header.SolicitedNodeAddr(addr.Address) - if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress { + if _, err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress { return err } @@ -1387,20 +1389,56 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { // JoinGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.joinGroupLocked(addr) +} + +// joinGroupLocked is like JoinGroup, but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) joinGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) { if !header.IsV6MulticastAddress(addr) { return false, tcpip.ErrBadAddress } - e.mu.Lock() - defer e.mu.Unlock() - return e.mu.addressableEndpointState.JoinGroup(addr) + // TODO(gvisor.dev/issue/4916): Keep track of join count and MLD state in a + // single type. + joined, err := e.mu.addressableEndpointState.JoinGroup(addr) + if err != nil || !joined { + return joined, err + } + + // joinGroup only returns an error if we try to join a group twice, but we + // checked above to make sure that the group was newly joined. + if err := e.mld.joinGroup(addr); err != nil { + panic(fmt.Sprintf("e.mld.joinGroup(%s): %s", addr, err)) + } + + return true, nil } // LeaveGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.LeaveGroup(addr) + return e.leaveGroupLocked(addr) +} + +// leaveGroupLocked is like LeaveGroup, but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) { + left, err := e.mu.addressableEndpointState.LeaveGroup(addr) + if err != nil { + return left, err + } + + if left { + e.mld.leaveGroup(addr) + } + + return left, nil } // IsInGroup implements stack.GroupAddressableEndpoint. @@ -1482,6 +1520,7 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), } e.mu.ndp.initializeTempAddrState() + e.mld.init(e, p.options.MLD) p.mu.Lock() defer p.mu.Unlock() @@ -1638,6 +1677,9 @@ type Options struct { // seed that is too small would reduce randomness and increase predictability, // defeating the purpose of temporary SLAAC addresses. TempIIDSeed []byte + + // MLD holds options for MLD. + MLD MLDOptions } // NewProtocolWithOptions returns an IPv6 network protocol. |