summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go57
-rw-r--r--pkg/tcpip/network/ipv4/igmp_test.go355
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go85
-rw-r--r--pkg/tcpip/network/ipv6/BUILD15
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go31
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go50
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go68
-rw-r--r--pkg/tcpip/network/ipv6/mld.go175
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go90
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go4
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go14
-rw-r--r--pkg/tcpip/network/multicast_group_test.go793
13 files changed, 1312 insertions, 427 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index b38aff0b8..9ebf31b78 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -7,12 +7,14 @@ go_test(
size = "small",
srcs = [
"ip_test.go",
+ "multicast_group_test.go",
],
deps = [
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/header/parse",
"//pkg/tcpip/link/channel",
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index c9bf117de..37f1822ca 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -51,6 +51,16 @@ const (
UnsolicitedReportIntervalMax = 10 * time.Second
)
+// IGMPOptions holds options for IGMP.
+type IGMPOptions struct {
+ // Enabled indicates whether IGMP will be performed.
+ //
+ // When enabled, IGMP may transmit IGMP report and leave messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // IGMP packets.
+ Enabled bool
+}
+
var _ ip.MulticastGroupProtocol = (*igmpState)(nil)
// igmpState is the per-interface IGMP state.
@@ -58,7 +68,8 @@ var _ ip.MulticastGroupProtocol = (*igmpState)(nil)
// igmpState.init() MUST be called after creating an IGMP state.
type igmpState struct {
// The IPv4 endpoint this igmpState is for.
- ep *endpoint
+ ep *endpoint
+ opts IGMPOptions
// igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from
// RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1
@@ -108,10 +119,11 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
// init sets up an igmpState struct, and is required to be called before using
// a new igmpState.
-func (igmp *igmpState) init(ep *endpoint) {
+func (igmp *igmpState) init(ep *endpoint, opts IGMPOptions) {
igmp.mu.Lock()
defer igmp.mu.Unlock()
igmp.ep = ep
+ igmp.opts = opts
igmp.mu.genericMulticastProtocol.Init(ep.protocol.stack.Rand(), ep.protocol.stack.Clock(), igmp, UnsolicitedReportIntervalMax)
igmp.igmpV1Present = igmpV1PresentDefault
igmp.mu.igmpV1Job = igmp.ep.protocol.stack.NewJob(&igmp.mu, func() {
@@ -189,6 +201,10 @@ func (igmp *igmpState) setV1Present(v bool) {
}
func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) {
+ if !igmp.opts.Enabled {
+ return
+ }
+
igmp.mu.Lock()
defer igmp.mu.Unlock()
@@ -206,6 +222,10 @@ func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxResp
}
func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) {
+ if !igmp.opts.Enabled {
+ return
+ }
+
igmp.mu.Lock()
defer igmp.mu.Unlock()
igmp.mu.genericMulticastProtocol.HandleReport(groupAddress)
@@ -226,11 +246,8 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
// TODO(gvisor.dev/issue/4888): We should not use the unspecified address,
// rather we should select an appropriate local address.
- r := stack.Route{
- LocalAddress: header.IPv4Any,
- RemoteAddress: destAddress,
- }
- igmp.ep.addIPHeader(&r, pkt, stack.NetworkHeaderParams{
+ localAddr := header.IPv4Any
+ igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.IGMPProtocolNumber,
TTL: header.IGMPTTL,
TOS: stack.DefaultTOS,
@@ -239,7 +256,7 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
// TODO(b/162198658): set the ROUTER_ALERT option when sending Host
// Membership Reports.
sent := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent
- if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
+ if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.Dropped.Increment()
return err
}
@@ -263,6 +280,26 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
// If the group already exists in the membership map, returns
// tcpip.ErrDuplicateAddress.
func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) *tcpip.Error {
+ if !igmp.opts.Enabled {
+ return nil
+ }
+
+ // As per RFC 2236 section 6 page 10,
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // This is equivalent to not performing IGMP for the all-systems multicast
+ // address. Simply not performing IGMP when the group is added will prevent
+ // any work from being done on the all-systems multicast group when leaving
+ // the group or when query or report messages are received for it since the
+ // MGP state will not know about it.
+ if groupAddress == header.IPv4AllSystems {
+ return nil
+ }
+
igmp.mu.Lock()
defer igmp.mu.Unlock()
@@ -280,6 +317,10 @@ func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) *tcpip.Error {
// If the group does not exist in the membership map, this function will
// silently return.
func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) {
+ if !igmp.opts.Enabled {
+ return
+ }
+
igmp.mu.Lock()
defer igmp.mu.Unlock()
igmp.mu.genericMulticastProtocol.LeaveGroup(groupAddress)
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
index 4873a336f..d83b6c4a4 100644
--- a/pkg/tcpip/network/ipv4/igmp_test.go
+++ b/pkg/tcpip/network/ipv4/igmp_test.go
@@ -15,9 +15,7 @@
package ipv4_test
import (
- "fmt"
"testing"
- "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -30,25 +28,11 @@ import (
)
const (
- linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- // endpointAddr = tcpip.Address("\x0a\x00\x00\x02")
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
multicastAddr = tcpip.Address("\xe0\x00\x00\x03")
nicID = 1
)
-var (
- // unsolicitedReportIntervalMaxTenthSec is the maximum amount of time the NIC
- // will wait before sending an unsolicited report after joining a multicast
- // group, in deciseconds.
- unsolicitedReportIntervalMaxTenthSec = func() uint8 {
- const decisecond = time.Second / 10
- if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 {
- panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax))
- }
- return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond)
- }()
-)
-
// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet
// sent to the provided address with the passed fields set. Raises a t.Error if
// any field does not match.
@@ -75,7 +59,9 @@ func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stac
clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{
- IGMPEnabled: igmpEnabled,
+ IGMP: ipv4.IGMPOptions{
+ Enabled: igmpEnabled,
+ },
})},
Clock: clock,
})
@@ -110,339 +96,6 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma
})
}
-// TestIgmpDisabled tests that IGMP is not enabled with a default
-// stack.Options. This also tests that this NIC does not send the IGMP Join
-// Group for the All Hosts group it automatically joins when created.
-func TestIgmpDisabled(t *testing.T) {
- e, s, _ := createStack(t, false)
-
- // This NIC will join the All Hosts group when created. Verify that does not
- // send a report.
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 0 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 0", got)
- }
- p, ok := e.Read()
- if ok {
- t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt)
- }
-
- // Test joining a specific group explicitly and verify that no reports are
- // sent.
- if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("JoinGroup(ipv4.ProtocolNumber, %d, %s) = %s", nicID, multicastAddr, err)
- }
-
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 0 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 0", got)
- }
- p, ok = e.Read()
- if ok {
- t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt)
- }
-
- // Inject a General Membership Query, which is an IGMP Membership Query with
- // a zeroed Group Address (IPv4Any) to verify that it does not reach the
- // handler.
- createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, unsolicitedReportIntervalMaxTenthSec, header.IPv4Any)
-
- if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 0 {
- t.Fatalf("got Membership Queries received = %d, want = 0", got)
- }
- p, ok = e.Read()
- if ok {
- t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt)
- }
-}
-
-// TestIgmpReceivesIGMPMessages tests that the IGMP stack increments packet
-// counters when it receives properly formatted Membership Queries, Membership
-// Reports, and LeaveGroup Messages sent to this address. Note: test includes
-// IGMP header fields that are not explicitly tested in order to inject proper
-// IGMP packets.
-func TestIgmpReceivesIGMPMessages(t *testing.T) {
- tests := []struct {
- name string
- headerType header.IGMPType
- maxRespTime byte
- groupAddress tcpip.Address
- statCounter func(tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter
- }{
- {
- name: "General Membership Query",
- headerType: header.IGMPMembershipQuery,
- maxRespTime: unsolicitedReportIntervalMaxTenthSec,
- groupAddress: header.IPv4Any,
- statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter {
- return stats.MembershipQuery
- },
- },
- {
- name: "IGMPv1 Membership Report",
- headerType: header.IGMPv1MembershipReport,
- maxRespTime: 0,
- groupAddress: header.IPv4AllSystems,
- statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter {
- return stats.V1MembershipReport
- },
- },
- {
- name: "IGMPv2 Membership Report",
- headerType: header.IGMPv2MembershipReport,
- maxRespTime: 0,
- groupAddress: header.IPv4AllSystems,
- statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter {
- return stats.V2MembershipReport
- },
- },
- {
- name: "Leave Group",
- headerType: header.IGMPLeaveGroup,
- maxRespTime: 0,
- groupAddress: header.IPv4AllRoutersGroup,
- statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter {
- return stats.LeaveGroup
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e, s, _ := createStack(t, true)
-
- createAndInjectIGMPPacket(e, test.headerType, test.maxRespTime, test.groupAddress)
-
- if got := test.statCounter(s.Stats().IGMP.PacketsReceived).Value(); got != 1 {
- t.Fatalf("got %s received = %d, want = 1", test.name, got)
- }
- })
- }
-}
-
-// TestIgmpJoinGroup tests that when explicitly joining a multicast group, the
-// IGMP stack schedules and sends correct Membership Reports.
-func TestIgmpJoinGroup(t *testing.T) {
- e, s, clock := createStack(t, true)
-
- // Test joining a specific address explicitly and verify a Membership Report
- // is sent immediately.
- if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
- }
-
- p, ok := e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
- }
-
- validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
- if t.Failed() {
- t.FailNow()
- }
-
- // Verify the second Membership Report is sent after a random interval up to
- // the maximum unsolicited report interval.
- p, ok = e.Read()
- if ok {
- t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt)
- }
- clock.Advance(ipv4.UnsolicitedReportIntervalMax)
- p, ok = e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got)
- }
- validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
-}
-
-// TestIgmpLeaveGroup tests that when leaving a previously joined multicast
-// group the IGMP enabled NIC sends the appropriate message.
-func TestIgmpLeaveGroup(t *testing.T) {
- e, s, clock := createStack(t, true)
-
- // Join a group so that it can be left, validate the immediate Membership
- // Report is sent only to the multicast address joined.
- if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
- }
- p, ok := e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
- }
- validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
- if t.Failed() {
- t.FailNow()
- }
-
- // Verify the second Membership Report is sent after a random interval up to
- // the maximum unsolicited report interval, and is sent to the multicast
- // address being joined.
- p, ok = e.Read()
- if ok {
- t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt)
- }
- clock.Advance(ipv4.UnsolicitedReportIntervalMax)
- p, ok = e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got)
- }
- validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
- if t.Failed() {
- t.FailNow()
- }
-
- // Now that there are no packets queued and none scheduled to be sent, leave
- // the group.
- if err := s.LeaveGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("LeaveGroup(ipv4, nic, %s) = %s", multicastAddr, err)
- }
-
- // Observe the Leave Group Message to verify that the Leave Group message is
- // sent to the All Routers group but that the message itself has the
- // multicast address being left.
- p, ok = e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected LeaveGroup")
- }
- if got := s.Stats().IGMP.PacketsSent.LeaveGroup.Value(); got != 1 {
- t.Fatalf("got LeaveGroup messages sent = %d, want = 1", got)
- }
- validateIgmpPacket(t, p, header.IPv4AllRoutersGroup, header.IGMPLeaveGroup, 0, multicastAddr)
-}
-
-// TestIgmpJoinLeaveGroup tests that when leaving a previously joined multicast
-// group before the Unsolicited Report Interval cancels the second membership
-// report.
-func TestIgmpJoinLeaveGroup(t *testing.T) {
- _, s, clock := createStack(t, true)
-
- if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
- }
-
- // Verify that this NIC sent a Membership Report for only the group just
- // joined.
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
- }
-
- if err := s.LeaveGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("LeaveGroup(ipv4, nic, %s) = %s", multicastAddr, err)
- }
-
- // Wait for the standard IGMP Unsolicited Report Interval duration before
- // verifying that the unsolicited Membership Report was sent after leaving
- // the group.
- clock.Advance(ipv4.UnsolicitedReportIntervalMax)
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
- }
-}
-
-// TestIgmpMembershipQueryReport tests the handling of both incoming IGMP
-// Membership Queries and outgoing Membership Reports.
-func TestIgmpMembershipQueryReport(t *testing.T) {
- e, s, clock := createStack(t, true)
-
- if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
- }
-
- p, ok := e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
- }
- validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
- if t.Failed() {
- t.FailNow()
- }
-
- p, ok = e.Read()
- if ok {
- t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt)
- }
- clock.Advance(ipv4.UnsolicitedReportIntervalMax)
- p, ok = e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got)
- }
- validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
-
- // Inject a General Membership Query, which is an IGMP Membership Query with
- // a zeroed Group Address (IPv4Any) with the shortened Max Response Time.
- const maxRespTimeDS = 10
- createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, maxRespTimeDS, header.IPv4Any)
-
- p, ok = e.Read()
- if ok {
- t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt)
- }
- clock.Advance(header.DecisecondToDuration(maxRespTimeDS))
- p, ok = e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 3 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 3", got)
- }
- validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
-}
-
-// TestIgmpMultipleHosts tests the handling of IGMP Leave when we are not the
-// most recent IGMP host to join a multicast network.
-func TestIgmpMultipleHosts(t *testing.T) {
- e, s, clock := createStack(t, true)
-
- if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
- }
-
- p, ok := e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
- }
- validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
- if t.Failed() {
- t.FailNow()
- }
-
- // Inject another Host's Join Group message so that this host is not the
- // latest to send the report. Set Max Response Time to 0 for Membership
- // Reports.
- createAndInjectIGMPPacket(e, header.IGMPv2MembershipReport, 0, multicastAddr)
-
- if err := s.LeaveGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("LeaveGroup(ipv4, nic, %s) = %s", multicastAddr, err)
- }
-
- // Wait to be sure that no Leave Group messages were sent up to the max
- // unsolicited report interval since it was not the last host to join this
- // group.
- clock.Advance(ipv4.UnsolicitedReportIntervalMax)
- if got := s.Stats().IGMP.PacketsSent.LeaveGroup.Value(); got != 0 {
- t.Fatalf("got LeaveGroup messages sent = %d, want = 0", got)
- }
-}
-
// TestIgmpV1Present tests the handling of the case where an IGMPv1 router is
// present on the network. The IGMP stack will then send IGMPv1 Membership
// reports for backwards compatibility.
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 7c759be9a..be9c8e2f9 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -95,7 +95,7 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa
protocol: p,
}
e.mu.addressableEndpointState.Init(e)
- e.igmp.init(e)
+ e.igmp.init(e, p.options.IGMP)
return e
}
@@ -126,7 +126,7 @@ func (e *endpoint) Enable() *tcpip.Error {
// As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
// multicast group. Note, the IANA calls the all-hosts multicast group the
// all-systems multicast group.
- _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems)
+ _, err = e.joinGroupLocked(header.IPv4AllSystems)
return err
}
@@ -164,7 +164,7 @@ func (e *endpoint) disableLocked() {
}
// The endpoint may have already left the multicast group.
- if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
+ if _, err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
}
@@ -200,7 +200,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
hdrLen := header.IPv4MinimumSize
var opts header.IPv4Options
if params.Options != nil {
@@ -221,15 +221,15 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
// RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic
// datagrams. Since the DF bit is never being set here, all datagrams
// are non-atomic and need an ID.
- id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1)
ip.Encode(&header.IPv4Fields{
TotalLength: length,
ID: uint16(id),
TTL: params.TTL,
TOS: params.TOS,
Protocol: uint8(params.Protocol),
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ SrcAddr: srcAddr,
+ DstAddr: dstAddr,
Options: opts,
})
ip.SetChecksum(^ip.CalculateChecksum())
@@ -261,7 +261,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- e.addIPHeader(r, pkt, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params)
// iptables filtering. All packets that reach here are locally
// generated.
@@ -349,7 +349,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.addIPHeader(r, pkt, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params)
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
@@ -463,7 +463,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// non-atomic datagrams, so assign an ID to all such datagrams
// according to the definition given in RFC 6864 section 4.
if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 {
- ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
+ ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress, r.RemoteAddress, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
}
}
@@ -706,10 +706,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
return
}
if p == header.IGMPProtocolNumber {
- if e.protocol.options.IGMPEnabled {
- e.igmp.handleIGMP(pkt)
- }
- // Nothing further to do with an IGMP packet, even if IGMP is not enabled.
+ e.igmp.handleIGMP(pkt)
return
}
if opts := h.Options(); len(opts) != 0 {
@@ -837,32 +834,55 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
// JoinGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.joinGroupLocked(addr)
+}
+
+// joinGroupLocked is like JoinGroup, but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) {
if !header.IsV4MulticastAddress(addr) {
return false, tcpip.ErrBadAddress
}
+ // TODO(gvisor.dev/issue/4916): Keep track of join count and IGMP state in a
+ // single type.
+ joined, err := e.mu.addressableEndpointState.JoinGroup(addr)
+ if err != nil || !joined {
+ return joined, err
+ }
- e.mu.Lock()
- defer e.mu.Unlock()
-
- joinedGroup, err := e.mu.addressableEndpointState.JoinGroup(addr)
- if err == nil && joinedGroup && e.protocol.options.IGMPEnabled {
- _ = e.igmp.joinGroup(addr)
+ // joinGroup only returns an error if we try to join a group twice, but we
+ // checked above to make sure that the group was newly joined.
+ if err := e.igmp.joinGroup(addr); err != nil {
+ panic(fmt.Sprintf("e.igmp.joinGroup(%s): %s", addr, err))
}
- return joinedGroup, err
+ return true, nil
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
+ return e.leaveGroupLocked(addr)
+}
+
+// leaveGroupLocked is like LeaveGroup, but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) {
+ left, err := e.mu.addressableEndpointState.LeaveGroup(addr)
+ if err != nil {
+ return left, err
+ }
- leftGroup, err := e.mu.addressableEndpointState.LeaveGroup(addr)
- if err == nil && leftGroup && e.protocol.options.IGMPEnabled {
+ if left {
e.igmp.leaveGroup(addr)
}
- return leftGroup, err
+ return left, nil
}
// IsInGroup implements stack.GroupAddressableEndpoint.
@@ -1021,20 +1041,19 @@ func addressToUint32(addr tcpip.Address) uint32 {
return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24
}
-// hashRoute calculates a hash value for the given route. It uses the source &
-// destination address, the transport protocol number and a 32-bit number to
-// generate the hash.
-func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
- a := addressToUint32(r.LocalAddress)
- b := addressToUint32(r.RemoteAddress)
+// hashRoute calculates a hash value for the given source/destination pair using
+// the addresses, transport protocol number and a 32-bit number to generate the
+// hash.
+func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
+ a := addressToUint32(srcAddr)
+ b := addressToUint32(dstAddr)
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
// Options holds options to configure a new protocol.
type Options struct {
- // IGMPEnabled indicates whether incoming IGMP packets will be handled and if
- // this endpoint will transmit IGMP packets on IGMP related events.
- IGMPEnabled bool
+ // IGMP holds options for IGMP.
+ IGMP IGMPOptions
}
// NewProtocolWithOptions returns an IPv4 network protocol.
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index 0ac24a6fb..5e75c8740 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -8,6 +8,7 @@ go_library(
"dhcpv6configurationfromndpra_string.go",
"icmp.go",
"ipv6.go",
+ "mld.go",
"ndp.go",
],
visibility = ["//visibility:public"],
@@ -19,6 +20,7 @@ go_library(
"//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
+ "//pkg/tcpip/network/ip",
"//pkg/tcpip/stack",
],
)
@@ -49,3 +51,16 @@ go_test(
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
+
+go_test(
+ name = "ipv6_x_test",
+ size = "small",
+ srcs = ["mld_test.go"],
+ deps = [
+ ":ipv6",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 386d98a29..0fb98c578 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -163,7 +163,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
}
// TODO(b/112892170): Meaningfully handle all ICMP types.
- switch h.Type() {
+ switch icmpType := h.Type(); icmpType {
case header.ICMPv6PacketTooBig:
received.PacketTooBig.Increment()
hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize)
@@ -358,7 +358,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize))
packet.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(packet.NDPPayload())
+ na := header.NDPNeighborAdvert(packet.MessageBody())
// As per RFC 4861 section 7.2.4:
//
@@ -644,8 +644,31 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
return
}
+ case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone:
+ var handler func(header.MLD)
+ switch icmpType {
+ case header.ICMPv6MulticastListenerQuery:
+ received.MulticastListenerQuery.Increment()
+ handler = e.mld.handleMulticastListenerQuery
+ case header.ICMPv6MulticastListenerReport:
+ received.MulticastListenerReport.Increment()
+ handler = e.mld.handleMulticastListenerReport
+ case header.ICMPv6MulticastListenerDone:
+ received.MulticastListenerDone.Increment()
+ default:
+ panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType))
+ }
+ if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+
+ if handler != nil {
+ handler(header.MLD(payload.ToView()))
+ }
+
default:
- received.Invalid.Increment()
+ received.Unrecognized.Increment()
}
}
@@ -681,7 +704,7 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
packet.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(packet.NDPPayload())
+ ns := header.NDPNeighborSolicit(packet.MessageBody())
ns.SetTargetAddress(targetAddr)
ns.Options().Serialize(optsSerializer)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 9bc02d851..ead0739e5 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -271,6 +271,22 @@ func TestICMPCounts(t *testing.T) {
typ: header.ICMPv6RedirectMsg,
size: header.ICMPv6MinimumSize,
},
+ {
+ typ: header.ICMPv6MulticastListenerQuery,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerReport,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerDone,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: 255, /* Unrecognized */
+ size: 50,
+ },
}
handleIPv6Payload := func(icmp header.ICMPv6) {
@@ -413,6 +429,22 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
typ: header.ICMPv6RedirectMsg,
size: header.ICMPv6MinimumSize,
},
+ {
+ typ: header.ICMPv6MulticastListenerQuery,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerReport,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerDone,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: 255, /* Unrecognized */
+ size: 50,
+ },
}
handleIPv6Payload := func(icmp header.ICMPv6) {
@@ -1543,7 +1575,7 @@ func TestPacketQueing(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
pkt := header.ICMPv6(hdr.Prepend(naSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address)
@@ -1592,7 +1624,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
return icmp
},
@@ -1612,7 +1644,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
ns.Options().Serialize(header.NDPOptionsSerializer{
header.NDPSourceLinkLayerAddressOption(linkAddr1),
@@ -1629,7 +1661,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
return icmp
},
@@ -1645,7 +1677,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
ns.Options().Serialize(header.NDPOptionsSerializer{
header.NDPSourceLinkLayerAddressOption(linkAddr1),
@@ -1662,7 +1694,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1683,7 +1715,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1702,7 +1734,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(false)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1722,7 +1754,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(false)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 181c50cc7..ac67d4ac5 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -86,6 +86,8 @@ type endpoint struct {
addressableEndpointState stack.AddressableEndpointState
ndp ndpState
}
+
+ mld mldState
}
// NICNameFromID is a function that returns a stable name for the specified NIC,
@@ -243,7 +245,7 @@ func (e *endpoint) Enable() *tcpip.Error {
// (NDP NS) messages may be sent to the All-Nodes multicast group if the
// source address of the NDP NS is the unspecified address, as per RFC 4861
// section 7.2.4.
- if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil {
+ if _, err := e.joinGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil {
return err
}
@@ -333,7 +335,7 @@ func (e *endpoint) disableLocked() {
e.stopDADForPermanentAddressesLocked()
// The endpoint may have already left the multicast group.
- if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ if _, err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
}
}
@@ -378,7 +380,7 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
-func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
length := uint16(pkt.Size())
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
@@ -386,8 +388,8 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
NextHeader: uint8(params.Protocol),
HopLimit: params.TTL,
TrafficClass: params.TOS,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ SrcAddr: srcAddr,
+ DstAddr: dstAddr,
})
pkt.NetworkProtocolNumber = ProtocolNumber
}
@@ -442,7 +444,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- e.addIPHeader(r, pkt, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params)
// iptables filtering. All packets that reach here are locally
// generated.
@@ -531,7 +533,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
linkMTU := e.nic.MTU()
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
- e.addIPHeader(r, pb, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params)
networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size()))
if err != nil {
@@ -1164,7 +1166,7 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
}
snmc := header.SolicitedNodeAddr(addr.Address)
- if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil {
+ if _, err := e.joinGroupLocked(snmc); err != nil {
return nil, err
}
@@ -1221,7 +1223,7 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn
}
snmc := header.SolicitedNodeAddr(addr.Address)
- if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
+ if _, err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
return err
}
@@ -1387,20 +1389,56 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
// JoinGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.joinGroupLocked(addr)
+}
+
+// joinGroupLocked is like JoinGroup, but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) {
if !header.IsV6MulticastAddress(addr) {
return false, tcpip.ErrBadAddress
}
- e.mu.Lock()
- defer e.mu.Unlock()
- return e.mu.addressableEndpointState.JoinGroup(addr)
+ // TODO(gvisor.dev/issue/4916): Keep track of join count and MLD state in a
+ // single type.
+ joined, err := e.mu.addressableEndpointState.JoinGroup(addr)
+ if err != nil || !joined {
+ return joined, err
+ }
+
+ // joinGroup only returns an error if we try to join a group twice, but we
+ // checked above to make sure that the group was newly joined.
+ if err := e.mld.joinGroup(addr); err != nil {
+ panic(fmt.Sprintf("e.mld.joinGroup(%s): %s", addr, err))
+ }
+
+ return true, nil
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
- return e.mu.addressableEndpointState.LeaveGroup(addr)
+ return e.leaveGroupLocked(addr)
+}
+
+// leaveGroupLocked is like LeaveGroup, but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) {
+ left, err := e.mu.addressableEndpointState.LeaveGroup(addr)
+ if err != nil {
+ return left, err
+ }
+
+ if left {
+ e.mld.leaveGroup(addr)
+ }
+
+ return left, nil
}
// IsInGroup implements stack.GroupAddressableEndpoint.
@@ -1482,6 +1520,7 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
}
e.mu.ndp.initializeTempAddrState()
+ e.mld.init(e, p.options.MLD)
p.mu.Lock()
defer p.mu.Unlock()
@@ -1638,6 +1677,9 @@ type Options struct {
// seed that is too small would reduce randomness and increase predictability,
// defeating the purpose of temporary SLAAC addresses.
TempIIDSeed []byte
+
+ // MLD holds options for MLD.
+ MLD MLDOptions
}
// NewProtocolWithOptions returns an IPv6 network protocol.
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
new file mode 100644
index 000000000..b16a1afb0
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -0,0 +1,175 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // UnsolicitedReportIntervalMax is the maximum delay between sending
+ // unsolicited MLD reports.
+ //
+ // Obtained from RFC 2710 Section 7.10.
+ UnsolicitedReportIntervalMax = 10 * time.Second
+)
+
+// MLDOptions holds options for MLD.
+type MLDOptions struct {
+ // Enabled indicates whether MLD will be performed.
+ //
+ // When enabled, MLD may transmit MLD report and done messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // MLD packets.
+ Enabled bool
+}
+
+var _ ip.MulticastGroupProtocol = (*mldState)(nil)
+
+// mldState is the per-interface MLD state.
+//
+// mldState.init MUST be called to initialize the MLD state.
+type mldState struct {
+ // The IPv6 endpoint this mldState is for.
+ ep *endpoint
+ opts MLDOptions
+
+ genericMulticastProtocol ip.GenericMulticastProtocolState
+}
+
+// SendReport implements ip.MulticastGroupProtocol.
+func (mld *mldState) SendReport(groupAddress tcpip.Address) *tcpip.Error {
+ return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport)
+}
+
+// SendLeave implements ip.MulticastGroupProtocol.
+func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+ return mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
+}
+
+// init sets up an mldState struct, and is required to be called before using
+// a new mldState.
+func (mld *mldState) init(ep *endpoint, opts MLDOptions) {
+ mld.ep = ep
+ mld.opts = opts
+ mld.genericMulticastProtocol.Init(ep.protocol.stack.Rand(), ep.protocol.stack.Clock(), mld, UnsolicitedReportIntervalMax)
+}
+
+func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) {
+ if !mld.opts.Enabled {
+ return
+ }
+
+ mld.genericMulticastProtocol.HandleQuery(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay())
+}
+
+func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) {
+ if !mld.opts.Enabled {
+ return
+ }
+
+ mld.genericMulticastProtocol.HandleReport(mldHdr.MulticastAddress())
+}
+
+// joinGroup handles joining a new group and sending and scheduling the required
+// messages.
+//
+// If the group is already joined, returns tcpip.ErrDuplicateAddress.
+func (mld *mldState) joinGroup(groupAddress tcpip.Address) *tcpip.Error {
+ if !mld.opts.Enabled {
+ return nil
+ }
+
+ // As per RFC 2710 section 5 page 10,
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ //
+ // This is equivalent to not performing MLD for the all-nodes multicast
+ // address. Simply not performing MLD when the group is added will prevent
+ // any work from being done on the all-nodes multicast group when leaving the
+ // group or when query or report messages are received for it since the MGP
+ // state will not know about it.
+ if groupAddress == header.IPv6AllNodesMulticastAddress {
+ return nil
+ }
+
+ // JoinGroup returns false if we have already joined the group.
+ if !mld.genericMulticastProtocol.JoinGroup(groupAddress) {
+ return tcpip.ErrDuplicateAddress
+ }
+ return nil
+}
+
+// leaveGroup handles removing the group from the membership map, cancels any
+// delay timers associated with that group, and sends the Done message, if
+// required.
+//
+// If the group is not joined, this function will do nothing.
+func (mld *mldState) leaveGroup(groupAddress tcpip.Address) {
+ if !mld.opts.Enabled {
+ return
+ }
+
+ mld.genericMulticastProtocol.LeaveGroup(groupAddress)
+}
+
+func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) *tcpip.Error {
+ sentStats := mld.ep.protocol.stack.Stats().ICMP.V6PacketsSent
+ var mldStat *tcpip.StatCounter
+ switch mldType {
+ case header.ICMPv6MulticastListenerReport:
+ mldStat = sentStats.MulticastListenerReport
+ case header.ICMPv6MulticastListenerDone:
+ mldStat = sentStats.MulticastListenerDone
+ default:
+ panic(fmt.Sprintf("unrecognized mld type = %d", mldType))
+ }
+
+ icmp := header.ICMPv6(buffer.NewView(header.ICMPv6HeaderSize + header.MLDMinimumSize))
+ icmp.SetType(mldType)
+ header.MLD(icmp.MessageBody()).SetMulticastAddress(groupAddress)
+ // TODO(gvisor.dev/issue/4888): We should not use the unspecified address,
+ // rather we should select an appropriate local address.
+ localAddress := header.IPv6Any
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{}))
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()),
+ Data: buffer.View(icmp).ToVectorisedView(),
+ })
+
+ mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.MLDHopLimit,
+ })
+ // TODO(b/162198658): set the ROUTER_ALERT option when sending Host
+ // Membership Reports.
+ if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
+ sentStats.Dropped.Increment()
+ return err
+ }
+ mldStat.Increment()
+ return nil
+}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
new file mode 100644
index 000000000..5677bdd54
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -0,0 +1,90 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+)
+
+func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ MLD: ipv6.MLDOptions{
+ Enabled: true,
+ },
+ })},
+ })
+ e := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ // The stack will join an address's solicited node multicast address when
+ // an address is added. An MLD report message should be sent for the
+ // solicited-node group.
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, addr1, err)
+ }
+ {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("expected a report message to be sent")
+ }
+ snmc := header.SolicitedNodeAddr(addr1)
+ checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())),
+ checker.DstAddr(snmc),
+ // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
+ checker.TTL(1),
+ checker.MLD(header.ICMPv6MulticastListenerReport, header.MLDMinimumSize,
+ checker.MLDMaxRespDelay(0),
+ checker.MLDMulticastAddress(snmc),
+ ),
+ )
+ }
+
+ // The stack will leave an address's solicited node multicast address when
+ // an address is removed. An MLD done message should be sent for the
+ // solicited-node group.
+ if err := s.RemoveAddress(nicID, addr1); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err)
+ }
+ {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("expected a done message to be sent")
+ }
+ snmc := header.SolicitedNodeAddr(addr1)
+ checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(1),
+ checker.MLD(header.ICMPv6MulticastListenerDone, header.MLDMinimumSize,
+ checker.MLDMaxRespDelay(0),
+ checker.MLDMulticastAddress(snmc),
+ ),
+ )
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index c138358af..b83e9d931 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -773,7 +773,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add
icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
icmpData.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmpData.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmpData.MessageBody())
ns.SetTargetAddress(addr)
icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
@@ -1944,7 +1944,7 @@ func (ndp *ndpState) startSolicitingRouters() {
payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
icmpData := header.ICMPv6(buffer.NewView(payloadSize))
icmpData.SetType(header.ICMPv6RouterSolicit)
- rs := header.NDPRouterSolicit(icmpData.NDPPayload())
+ rs := header.NDPRouterSolicit(icmpData.MessageBody())
rs.Options().Serialize(optsSerializer)
icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 9fbd0d336..56dc4a6ec 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -205,7 +205,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(lladdr0)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -311,7 +311,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(lladdr0)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -591,7 +591,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(nicAddr)
opts := ns.Options()
opts.Serialize(test.nsOpts)
@@ -672,7 +672,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(test.nsSrc)
@@ -777,7 +777,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns := header.NDPNeighborAdvert(pkt.MessageBody())
ns.SetTargetAddress(lladdr1)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -890,7 +890,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns := header.NDPNeighborAdvert(pkt.MessageBody())
ns.SetTargetAddress(lladdr1)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -1346,7 +1346,7 @@ func TestRouterAdvertValidation(t *testing.T) {
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
pkt.SetCode(test.code)
- copy(pkt.NDPPayload(), test.ndpPayload)
+ copy(pkt.MessageBody(), test.ndpPayload)
payloadLength := hdr.UsedLength()
pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
new file mode 100644
index 000000000..12f9fbe82
--- /dev/null
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -0,0 +1,793 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip_test
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+
+ ipv4MulticastAddr = tcpip.Address("\xe0\x00\x00\x03")
+ ipv6MulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+
+ igmpMembershipQuery = uint8(header.IGMPMembershipQuery)
+ igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport)
+ igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport)
+ igmpLeaveGroup = uint8(header.IGMPLeaveGroup)
+ mldQuery = uint8(header.ICMPv6MulticastListenerQuery)
+ mldReport = uint8(header.ICMPv6MulticastListenerReport)
+ mldDone = uint8(header.ICMPv6MulticastListenerDone)
+)
+
+var (
+ // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the
+ // NIC will wait before sending an unsolicited report after joining a
+ // multicast group, in deciseconds.
+ unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 {
+ const decisecond = time.Second / 10
+ if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 {
+ panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax))
+ }
+ return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond)
+ }()
+)
+
+// validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet
+// sent to the provided address with the passed fields set.
+func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) {
+ t.Helper()
+
+ payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ checker.IPv6(t, payload,
+ checker.DstAddr(remoteAddress),
+ // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
+ checker.TTL(1),
+ checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize,
+ checker.MLDMaxRespDelay(time.Duration(maxRespTime)*time.Millisecond),
+ checker.MLDMulticastAddress(groupAddress),
+ ),
+ )
+}
+
+// validateIGMPPacket checks that a passed PacketInfo is an IPv4 IGMP packet
+// sent to the provided address with the passed fields set.
+func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) {
+ t.Helper()
+
+ payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ checker.IPv4(t, payload,
+ checker.DstAddr(remoteAddress),
+ // TTL for an IGMP message must be 1 as per RFC 2236 section 2.
+ checker.TTL(1),
+ checker.IGMP(
+ checker.IGMPType(header.IGMPType(igmpType)),
+ checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)),
+ checker.IGMPGroupAddress(groupAddress),
+ ),
+ )
+}
+
+func createStack(t *testing.T, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
+ t.Helper()
+
+ // Create an endpoint of queue size 1, since no more than 1 packets are ever
+ // queued in the tests in this file.
+ e := channel.New(1, 1280, linkAddr)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocolWithOptions(ipv4.Options{
+ IGMP: ipv4.IGMPOptions{
+ Enabled: mgpEnabled,
+ },
+ }),
+ ipv6.NewProtocolWithOptions(ipv6.Options{
+ MLD: ipv6.MLDOptions{
+ Enabled: mgpEnabled,
+ },
+ }),
+ },
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ return e, s, clock
+}
+
+// createAndInjectIGMPPacket creates and injects an IGMP packet with the
+// specified fields.
+//
+// Note, the router alert option is not included in this packet.
+//
+// TODO(b/162198658): set the router alert option.
+func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address) {
+ buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize)
+
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(len(buf)),
+ TTL: header.IGMPTTL,
+ Protocol: uint8(header.IGMPProtocolNumber),
+ SrcAddr: header.IPv4Any,
+ DstAddr: header.IPv4AllSystems,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ igmp := header.IGMP(buf[header.IPv4MinimumSize:])
+ igmp.SetType(header.IGMPType(igmpType))
+ igmp.SetMaxRespTime(maxRespTime)
+ igmp.SetGroupAddress(groupAddress)
+ igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
+
+ e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+}
+
+// createAndInjectMLDPacket creates and injects an MLD packet with the
+// specified fields.
+//
+// Note, the router alert option is not included in this packet.
+//
+// TODO(b/162198658): set the router alert option.
+func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address) {
+ icmpSize := header.ICMPv6HeaderSize + header.MLDMinimumSize
+ buf := buffer.NewView(header.IPv6MinimumSize + icmpSize)
+
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(icmpSize),
+ HopLimit: header.MLDHopLimit,
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ SrcAddr: header.IPv4Any,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
+
+ icmp := header.ICMPv6(buf[header.IPv6MinimumSize:])
+ icmp.SetType(header.ICMPv6Type(mldType))
+ mld := header.MLD(icmp.MessageBody())
+ mld.SetMaximumResponseDelay(uint16(maxRespDelay))
+ mld.SetMulticastAddress(groupAddress)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+
+ e.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+}
+
+// TestMGPDisabled tests that the multicast group protocol is not enabled by
+// default.
+func TestMGPDisabled(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
+ rxQuery func(*channel.Endpoint)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ rxQuery: func(e *channel.Endpoint) {
+ createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any)
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6PacketsSent.MulticastListenerReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6PacketsReceived.MulticastListenerQuery
+ },
+ rxQuery: func(e *channel.Endpoint) {
+ createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, false)
+
+ // This NIC may join multicast groups when it is enabled but since MGP is
+ // disabled, no reports should be sent.
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportState.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt)
+ }
+
+ // Test joining a specific group explicitly and verify that no reports are
+ // sent.
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportState.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %#v", p.Pkt)
+ }
+
+ // Inject a general query message. This should only trigger a report to be
+ // sent if the MGP was enabled.
+ test.rxQuery(e)
+ if got := test.receivedQueryStat(s).Value(); got != 1 {
+ t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt)
+ }
+ })
+ }
+}
+
+func TestMGPReceiveCounters(t *testing.T) {
+ tests := []struct {
+ name string
+ headerType uint8
+ maxRespTime byte
+ groupAddress tcpip.Address
+ statCounter func(*stack.Stack) *tcpip.StatCounter
+ rxMGPkt func(*channel.Endpoint, byte, byte, tcpip.Address)
+ }{
+ {
+ name: "IGMP Membership Query",
+ headerType: igmpMembershipQuery,
+ maxRespTime: unsolicitedIGMPReportIntervalMaxTenthSec,
+ groupAddress: header.IPv4Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "IGMPv1 Membership Report",
+ headerType: igmpv1MembershipReport,
+ maxRespTime: 0,
+ groupAddress: header.IPv4AllSystems,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.V1MembershipReport
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "IGMPv2 Membership Report",
+ headerType: igmpv2MembershipReport,
+ maxRespTime: 0,
+ groupAddress: header.IPv4AllSystems,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.V2MembershipReport
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "IGMP Leave Group",
+ headerType: igmpLeaveGroup,
+ maxRespTime: 0,
+ groupAddress: header.IPv4AllRoutersGroup,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.LeaveGroup
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "MLD Query",
+ headerType: mldQuery,
+ maxRespTime: 0,
+ groupAddress: header.IPv6Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6PacketsReceived.MulticastListenerQuery
+ },
+ rxMGPkt: createAndInjectMLDPacket,
+ },
+ {
+ name: "MLD Report",
+ headerType: mldReport,
+ maxRespTime: 0,
+ groupAddress: header.IPv6Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6PacketsReceived.MulticastListenerReport
+ },
+ rxMGPkt: createAndInjectMLDPacket,
+ },
+ {
+ name: "MLD Done",
+ headerType: mldDone,
+ maxRespTime: 0,
+ groupAddress: header.IPv6Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6PacketsReceived.MulticastListenerDone
+ },
+ rxMGPkt: createAndInjectMLDPacket,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, _ := createStack(t, true)
+
+ test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress)
+ if got := test.statCounter(s).Value(); got != 1 {
+ t.Fatalf("got %s received = %d, want = 1", test.name, got)
+ }
+ })
+ }
+}
+
+// TestMGPJoinGroup tests that when explicitly joining a multicast group, the
+// stack schedules and sends correct Membership Reports.
+func TestMGPJoinGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ maxUnsolicitedResponseDelay time.Duration
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
+ validateReport func(*testing.T, channel.PacketInfo)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr,
+ maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ validateIGMPPacket(t, p, ipv4MulticastAddr, igmpv2MembershipReport, 0, ipv4MulticastAddr)
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr,
+ maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6PacketsSent.MulticastListenerReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6PacketsReceived.MulticastListenerQuery
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ validateMLDPacket(t, p, ipv6MulticastAddr, mldReport, 0, ipv6MulticastAddr)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, true)
+
+ // Test joining a specific address explicitly and verify a Report is sent
+ // immediately.
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != 1 {
+ t.Errorf("got sentReportState.Value() = %d, want = 1", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Verify the second report is sent by the maximum unsolicited response
+ // interval.
+ p, ok := e.Read()
+ if ok {
+ t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt)
+ }
+ clock.Advance(test.maxUnsolicitedResponseDelay)
+ if got := sentReportStat.Value(); got != 2 {
+ t.Errorf("got sentReportState.Value() = %d, want = 2", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
+
+// TestMGPLeaveGroup tests that when leaving a previously joined multicast
+// group the stack sends a leave/done message.
+func TestMGPLeaveGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
+ validateReport func(*testing.T, channel.PacketInfo)
+ validateLeave func(*testing.T, channel.PacketInfo)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.LeaveGroup
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ validateIGMPPacket(t, p, ipv4MulticastAddr, igmpv2MembershipReport, 0, ipv4MulticastAddr)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo) {
+ validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr)
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6PacketsSent.MulticastListenerReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6PacketsSent.MulticastListenerDone
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ validateMLDPacket(t, p, ipv6MulticastAddr, mldReport, 0, ipv6MulticastAddr)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo) {
+ validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, true)
+
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ if got := test.sentReportStat(s).Value(); got != 1 {
+ t.Errorf("got sentReportStat(_).Value() = %d, want = 1", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leaving the group should trigger an leave/done message to be sent.
+ if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
+ }
+ if got := test.sentLeaveStat(s).Value(); got != 1 {
+ t.Fatalf("got sentLeaveStat(_).Value() = %d, want = 1", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a leave message to be sent")
+ } else {
+ test.validateLeave(t, p)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
+
+// TestMGPQueryMessages tests that a report is sent in response to query
+// messages.
+func TestMGPQueryMessages(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ maxUnsolicitedResponseDelay time.Duration
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
+ rxQuery func(*channel.Endpoint, uint8, tcpip.Address)
+ validateReport func(*testing.T, channel.PacketInfo)
+ maxRespTimeToDuration func(uint8) time.Duration
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr,
+ maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
+ createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ validateIGMPPacket(t, p, ipv4MulticastAddr, igmpv2MembershipReport, 0, ipv4MulticastAddr)
+ },
+ maxRespTimeToDuration: header.DecisecondToDuration,
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr,
+ maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6PacketsSent.MulticastListenerReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6PacketsReceived.MulticastListenerQuery
+ },
+ rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
+ createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ validateMLDPacket(t, p, ipv6MulticastAddr, mldReport, 0, ipv6MulticastAddr)
+ },
+ maxRespTimeToDuration: func(d uint8) time.Duration {
+ return time.Duration(d) * time.Millisecond
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ subTests := []struct {
+ name string
+ multicastAddr tcpip.Address
+ expectReport bool
+ }{
+ {
+ name: "Unspecified",
+ multicastAddr: tcpip.Address(strings.Repeat("\x00", len(test.multicastAddr))),
+ expectReport: true,
+ },
+ {
+ name: "Specified",
+ multicastAddr: test.multicastAddr,
+ expectReport: true,
+ },
+ {
+ name: "Specified other address",
+ multicastAddr: func() tcpip.Address {
+ addrBytes := []byte(test.multicastAddr)
+ addrBytes[len(addrBytes)-1]++
+ return tcpip.Address(addrBytes)
+ }(),
+ expectReport: false,
+ },
+ }
+
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ e, s, clock := createStack(t, true)
+
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ sentReportStat := test.sentReportStat(s)
+ for i := uint64(1); i <= 2; i++ {
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != i {
+ t.Errorf("(i=%d) got sentReportState.Value() = %d, want = %d", i, got, i)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatalf("expected %d-th report message to be sent", i)
+ } else {
+ test.validateReport(t, p)
+ }
+ clock.Advance(test.maxUnsolicitedResponseDelay)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Should not send any more packets until a query.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+
+ // Receive a query message which should trigger a report to be sent at
+ // some time before the maximum response time if the report is
+ // targeted at the host.
+ const maxRespTime = 100
+ test.rxQuery(e, maxRespTime, subTest.multicastAddr)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p.Pkt)
+ }
+
+ if subTest.expectReport {
+ clock.Advance(test.maxRespTimeToDuration(maxRespTime))
+ if got := sentReportStat.Value(); got != 3 {
+ t.Errorf("got sentReportState.Value() = %d, want = 3", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+ })
+ }
+}
+
+// TestMGPQueryMessages tests that no further reports or leave/done messages
+// are sent after receiving a report.
+func TestMGPReportMessages(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
+ rxReport func(*channel.Endpoint)
+ validateReport func(*testing.T, channel.PacketInfo)
+ maxRespTimeToDuration func(uint8) time.Duration
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.LeaveGroup
+ },
+ rxReport: func(e *channel.Endpoint) {
+ createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ validateIGMPPacket(t, p, ipv4MulticastAddr, igmpv2MembershipReport, 0, ipv4MulticastAddr)
+ },
+ maxRespTimeToDuration: header.DecisecondToDuration,
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6PacketsSent.MulticastListenerReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6PacketsSent.MulticastListenerDone
+ },
+ rxReport: func(e *channel.Endpoint) {
+ createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ validateMLDPacket(t, p, ipv6MulticastAddr, mldReport, 0, ipv6MulticastAddr)
+ },
+ maxRespTimeToDuration: func(d uint8) time.Duration {
+ return time.Duration(d) * time.Millisecond
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, true)
+
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != 1 {
+ t.Errorf("got sentReportStat.Value() = %d, want = 1", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Receiving a report for a group we joined should cancel any further
+ // reports.
+ test.rxReport(e)
+ clock.Advance(time.Hour)
+ if got := sentReportStat.Value(); got != 1 {
+ t.Errorf("got sentReportStat.Value() = %d, want = 1", got)
+ }
+ if p, ok := e.Read(); ok {
+ t.Errorf("sent unexpected packet = %#v", p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leaving a group after getting a report should not send a leave/done
+ // message.
+ if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
+ }
+ clock.Advance(time.Hour)
+ if got := test.sentLeaveStat(s).Value(); got != 0 {
+ t.Fatalf("got sentLeaveStat(_).Value() = %d, want = 0", got)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}