diff options
Diffstat (limited to 'pkg/tcpip/network/ipv6')
-rw-r--r-- | pkg/tcpip/network/ipv6/BUILD | 15 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 31 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 50 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 68 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/mld.go | 175 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/mld_test.go | 90 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp_test.go | 14 |
8 files changed, 412 insertions, 35 deletions
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index 0ac24a6fb..5e75c8740 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -8,6 +8,7 @@ go_library( "dhcpv6configurationfromndpra_string.go", "icmp.go", "ipv6.go", + "mld.go", "ndp.go", ], visibility = ["//visibility:public"], @@ -19,6 +20,7 @@ go_library( "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", + "//pkg/tcpip/network/ip", "//pkg/tcpip/stack", ], ) @@ -49,3 +51,16 @@ go_test( "@com_github_google_go_cmp//cmp:go_default_library", ], ) + +go_test( + name = "ipv6_x_test", + size = "small", + srcs = ["mld_test.go"], + deps = [ + ":ipv6", + "//pkg/tcpip/checker", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 386d98a29..0fb98c578 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -163,7 +163,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } // TODO(b/112892170): Meaningfully handle all ICMP types. - switch h.Type() { + switch icmpType := h.Type(); icmpType { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) @@ -358,7 +358,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize)) packet.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(packet.NDPPayload()) + na := header.NDPNeighborAdvert(packet.MessageBody()) // As per RFC 4861 section 7.2.4: // @@ -644,8 +644,31 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { return } + case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone: + var handler func(header.MLD) + switch icmpType { + case header.ICMPv6MulticastListenerQuery: + received.MulticastListenerQuery.Increment() + handler = e.mld.handleMulticastListenerQuery + case header.ICMPv6MulticastListenerReport: + received.MulticastListenerReport.Increment() + handler = e.mld.handleMulticastListenerReport + case header.ICMPv6MulticastListenerDone: + received.MulticastListenerDone.Increment() + default: + panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) + } + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize { + received.Invalid.Increment() + return + } + + if handler != nil { + handler(header.MLD(payload.ToView())) + } + default: - received.Invalid.Increment() + received.Unrecognized.Increment() } } @@ -681,7 +704,7 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) packet.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(packet.NDPPayload()) + ns := header.NDPNeighborSolicit(packet.MessageBody()) ns.SetTargetAddress(targetAddr) ns.Options().Serialize(optsSerializer) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 9bc02d851..ead0739e5 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -271,6 +271,22 @@ func TestICMPCounts(t *testing.T) { typ: header.ICMPv6RedirectMsg, size: header.ICMPv6MinimumSize, }, + { + typ: header.ICMPv6MulticastListenerQuery, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerReport, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerDone, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: 255, /* Unrecognized */ + size: 50, + }, } handleIPv6Payload := func(icmp header.ICMPv6) { @@ -413,6 +429,22 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { typ: header.ICMPv6RedirectMsg, size: header.ICMPv6MinimumSize, }, + { + typ: header.ICMPv6MulticastListenerQuery, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerReport, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerDone, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: 255, /* Unrecognized */ + size: 50, + }, } handleIPv6Payload := func(icmp header.ICMPv6) { @@ -1543,7 +1575,7 @@ func TestPacketQueing(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) pkt := header.ICMPv6(hdr.Prepend(naSize)) pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na := header.NDPNeighborAdvert(pkt.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address) @@ -1592,7 +1624,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) return icmp }, @@ -1612,7 +1644,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) ns.Options().Serialize(header.NDPOptionsSerializer{ header.NDPSourceLinkLayerAddressOption(linkAddr1), @@ -1629,7 +1661,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) return icmp }, @@ -1645,7 +1677,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) ns.Options().Serialize(header.NDPOptionsSerializer{ header.NDPSourceLinkLayerAddressOption(linkAddr1), @@ -1662,7 +1694,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) @@ -1683,7 +1715,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) @@ -1702,7 +1734,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(false) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) @@ -1722,7 +1754,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(false) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 181c50cc7..ac67d4ac5 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -86,6 +86,8 @@ type endpoint struct { addressableEndpointState stack.AddressableEndpointState ndp ndpState } + + mld mldState } // NICNameFromID is a function that returns a stable name for the specified NIC, @@ -243,7 +245,7 @@ func (e *endpoint) Enable() *tcpip.Error { // (NDP NS) messages may be sent to the All-Nodes multicast group if the // source address of the NDP NS is the unspecified address, as per RFC 4861 // section 7.2.4. - if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil { + if _, err := e.joinGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil { return err } @@ -333,7 +335,7 @@ func (e *endpoint) disableLocked() { e.stopDADForPermanentAddressesLocked() // The endpoint may have already left the multicast group. - if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + if _, err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) } } @@ -378,7 +380,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } -func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { length := uint16(pkt.Size()) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -386,8 +388,8 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s NextHeader: uint8(params.Protocol), HopLimit: params.TTL, TrafficClass: params.TOS, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + SrcAddr: srcAddr, + DstAddr: dstAddr, }) pkt.NetworkProtocolNumber = ProtocolNumber } @@ -442,7 +444,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r, pkt, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params) // iptables filtering. All packets that reach here are locally // generated. @@ -531,7 +533,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { - e.addIPHeader(r, pb, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params) networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size())) if err != nil { @@ -1164,7 +1166,7 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre } snmc := header.SolicitedNodeAddr(addr.Address) - if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil { + if _, err := e.joinGroupLocked(snmc); err != nil { return nil, err } @@ -1221,7 +1223,7 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn } snmc := header.SolicitedNodeAddr(addr.Address) - if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress { + if _, err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress { return err } @@ -1387,20 +1389,56 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { // JoinGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.joinGroupLocked(addr) +} + +// joinGroupLocked is like JoinGroup, but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) joinGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) { if !header.IsV6MulticastAddress(addr) { return false, tcpip.ErrBadAddress } - e.mu.Lock() - defer e.mu.Unlock() - return e.mu.addressableEndpointState.JoinGroup(addr) + // TODO(gvisor.dev/issue/4916): Keep track of join count and MLD state in a + // single type. + joined, err := e.mu.addressableEndpointState.JoinGroup(addr) + if err != nil || !joined { + return joined, err + } + + // joinGroup only returns an error if we try to join a group twice, but we + // checked above to make sure that the group was newly joined. + if err := e.mld.joinGroup(addr); err != nil { + panic(fmt.Sprintf("e.mld.joinGroup(%s): %s", addr, err)) + } + + return true, nil } // LeaveGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.LeaveGroup(addr) + return e.leaveGroupLocked(addr) +} + +// leaveGroupLocked is like LeaveGroup, but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) { + left, err := e.mu.addressableEndpointState.LeaveGroup(addr) + if err != nil { + return left, err + } + + if left { + e.mld.leaveGroup(addr) + } + + return left, nil } // IsInGroup implements stack.GroupAddressableEndpoint. @@ -1482,6 +1520,7 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), } e.mu.ndp.initializeTempAddrState() + e.mld.init(e, p.options.MLD) p.mu.Lock() defer p.mu.Unlock() @@ -1638,6 +1677,9 @@ type Options struct { // seed that is too small would reduce randomness and increase predictability, // defeating the purpose of temporary SLAAC addresses. TempIIDSeed []byte + + // MLD holds options for MLD. + MLD MLDOptions } // NewProtocolWithOptions returns an IPv6 network protocol. diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go new file mode 100644 index 000000000..b16a1afb0 --- /dev/null +++ b/pkg/tcpip/network/ipv6/mld.go @@ -0,0 +1,175 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + // UnsolicitedReportIntervalMax is the maximum delay between sending + // unsolicited MLD reports. + // + // Obtained from RFC 2710 Section 7.10. + UnsolicitedReportIntervalMax = 10 * time.Second +) + +// MLDOptions holds options for MLD. +type MLDOptions struct { + // Enabled indicates whether MLD will be performed. + // + // When enabled, MLD may transmit MLD report and done messages when + // joining and leaving multicast groups respectively, and handle incoming + // MLD packets. + Enabled bool +} + +var _ ip.MulticastGroupProtocol = (*mldState)(nil) + +// mldState is the per-interface MLD state. +// +// mldState.init MUST be called to initialize the MLD state. +type mldState struct { + // The IPv6 endpoint this mldState is for. + ep *endpoint + opts MLDOptions + + genericMulticastProtocol ip.GenericMulticastProtocolState +} + +// SendReport implements ip.MulticastGroupProtocol. +func (mld *mldState) SendReport(groupAddress tcpip.Address) *tcpip.Error { + return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport) +} + +// SendLeave implements ip.MulticastGroupProtocol. +func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { + return mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) +} + +// init sets up an mldState struct, and is required to be called before using +// a new mldState. +func (mld *mldState) init(ep *endpoint, opts MLDOptions) { + mld.ep = ep + mld.opts = opts + mld.genericMulticastProtocol.Init(ep.protocol.stack.Rand(), ep.protocol.stack.Clock(), mld, UnsolicitedReportIntervalMax) +} + +func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) { + if !mld.opts.Enabled { + return + } + + mld.genericMulticastProtocol.HandleQuery(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay()) +} + +func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) { + if !mld.opts.Enabled { + return + } + + mld.genericMulticastProtocol.HandleReport(mldHdr.MulticastAddress()) +} + +// joinGroup handles joining a new group and sending and scheduling the required +// messages. +// +// If the group is already joined, returns tcpip.ErrDuplicateAddress. +func (mld *mldState) joinGroup(groupAddress tcpip.Address) *tcpip.Error { + if !mld.opts.Enabled { + return nil + } + + // As per RFC 2710 section 5 page 10, + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + // + // This is equivalent to not performing MLD for the all-nodes multicast + // address. Simply not performing MLD when the group is added will prevent + // any work from being done on the all-nodes multicast group when leaving the + // group or when query or report messages are received for it since the MGP + // state will not know about it. + if groupAddress == header.IPv6AllNodesMulticastAddress { + return nil + } + + // JoinGroup returns false if we have already joined the group. + if !mld.genericMulticastProtocol.JoinGroup(groupAddress) { + return tcpip.ErrDuplicateAddress + } + return nil +} + +// leaveGroup handles removing the group from the membership map, cancels any +// delay timers associated with that group, and sends the Done message, if +// required. +// +// If the group is not joined, this function will do nothing. +func (mld *mldState) leaveGroup(groupAddress tcpip.Address) { + if !mld.opts.Enabled { + return + } + + mld.genericMulticastProtocol.LeaveGroup(groupAddress) +} + +func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) *tcpip.Error { + sentStats := mld.ep.protocol.stack.Stats().ICMP.V6PacketsSent + var mldStat *tcpip.StatCounter + switch mldType { + case header.ICMPv6MulticastListenerReport: + mldStat = sentStats.MulticastListenerReport + case header.ICMPv6MulticastListenerDone: + mldStat = sentStats.MulticastListenerDone + default: + panic(fmt.Sprintf("unrecognized mld type = %d", mldType)) + } + + icmp := header.ICMPv6(buffer.NewView(header.ICMPv6HeaderSize + header.MLDMinimumSize)) + icmp.SetType(mldType) + header.MLD(icmp.MessageBody()).SetMulticastAddress(groupAddress) + // TODO(gvisor.dev/issue/4888): We should not use the unspecified address, + // rather we should select an appropriate local address. + localAddress := header.IPv6Any + icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{})) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()), + Data: buffer.View(icmp).ToVectorisedView(), + }) + + mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.MLDHopLimit, + }) + // TODO(b/162198658): set the ROUTER_ALERT option when sending Host + // Membership Reports. + if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + sentStats.Dropped.Increment() + return err + } + mldStat.Increment() + return nil +} diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go new file mode 100644 index 000000000..5677bdd54 --- /dev/null +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -0,0 +1,90 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" +) + +func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + MLD: ipv6.MLDOptions{ + Enabled: true, + }, + })}, + }) + e := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + // The stack will join an address's solicited node multicast address when + // an address is added. An MLD report message should be sent for the + // solicited-node group. + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, addr1, err) + } + { + p, ok := e.Read() + if !ok { + t.Fatal("expected a report message to be sent") + } + snmc := header.SolicitedNodeAddr(addr1) + checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())), + checker.DstAddr(snmc), + // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. + checker.TTL(1), + checker.MLD(header.ICMPv6MulticastListenerReport, header.MLDMinimumSize, + checker.MLDMaxRespDelay(0), + checker.MLDMulticastAddress(snmc), + ), + ) + } + + // The stack will leave an address's solicited node multicast address when + // an address is removed. An MLD done message should be sent for the + // solicited-node group. + if err := s.RemoveAddress(nicID, addr1); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) + } + { + p, ok := e.Read() + if !ok { + t.Fatal("expected a done message to be sent") + } + snmc := header.SolicitedNodeAddr(addr1) + checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(1), + checker.MLD(header.ICMPv6MulticastListenerDone, header.MLDMinimumSize, + checker.MLDMaxRespDelay(0), + checker.MLDMulticastAddress(snmc), + ), + ) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index c138358af..b83e9d931 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -773,7 +773,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) icmpData.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmpData.NDPPayload()) + ns := header.NDPNeighborSolicit(icmpData.MessageBody()) ns.SetTargetAddress(addr) icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) @@ -1944,7 +1944,7 @@ func (ndp *ndpState) startSolicitingRouters() { payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) icmpData := header.ICMPv6(buffer.NewView(payloadSize)) icmpData.SetType(header.ICMPv6RouterSolicit) - rs := header.NDPRouterSolicit(icmpData.NDPPayload()) + rs := header.NDPRouterSolicit(icmpData.MessageBody()) rs.Options().Serialize(optsSerializer) icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 9fbd0d336..56dc4a6ec 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -205,7 +205,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(lladdr0) opts := ns.Options() copy(opts, test.optsBuf) @@ -311,7 +311,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(lladdr0) opts := ns.Options() copy(opts, test.optsBuf) @@ -591,7 +591,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(nicAddr) opts := ns.Options() opts.Serialize(test.nsOpts) @@ -672,7 +672,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na := header.NDPNeighborAdvert(pkt.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(test.nsSrc) @@ -777,7 +777,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) pkt.SetType(header.ICMPv6NeighborAdvert) - ns := header.NDPNeighborAdvert(pkt.NDPPayload()) + ns := header.NDPNeighborAdvert(pkt.MessageBody()) ns.SetTargetAddress(lladdr1) opts := ns.Options() copy(opts, test.optsBuf) @@ -890,7 +890,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) pkt.SetType(header.ICMPv6NeighborAdvert) - ns := header.NDPNeighborAdvert(pkt.NDPPayload()) + ns := header.NDPNeighborAdvert(pkt.MessageBody()) ns.SetTargetAddress(lladdr1) opts := ns.Options() copy(opts, test.optsBuf) @@ -1346,7 +1346,7 @@ func TestRouterAdvertValidation(t *testing.T) { pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) pkt.SetCode(test.code) - copy(pkt.NDPPayload(), test.ndpPayload) + copy(pkt.MessageBody(), test.ndpPayload) payloadLength := hdr.UsedLength() pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) |