From fbc4a8dbd1d1939854dbeb6ccfd4c6267f85c9ec Mon Sep 17 00:00:00 2001 From: Ryan Heacock Date: Thu, 19 Nov 2020 18:13:24 -0800 Subject: Perform IGMPv2 when joining IPv4 multicast groups Added headers, stats, checksum parsing capabilities from RFC 2236 describing IGMPv2. IGMPv2 state machine is implemented for each condition, sending and receiving IGMP Membership Reports and Leave Group messages with backwards compatibility with IGMPv1 routers. Test: * Implemented igmp header parser and checksum calculator in header/igmp_test.go * ipv4/igmp_test.go tests incoming and outgoing IGMP messages and pathways. * Added unit test coverage for IGMPv2 RFC behavior + IGMPv1 backwards compatibility in ipv4/igmp_test.go. Fixes #4682 PiperOrigin-RevId: 343408809 --- pkg/tcpip/checker/checker.go | 68 +++++ pkg/tcpip/header/BUILD | 2 + pkg/tcpip/header/igmp.go | 166 ++++++++++++ pkg/tcpip/header/igmp_test.go | 101 ++++++++ pkg/tcpip/header/ipv4.go | 3 + pkg/tcpip/network/ipv4/BUILD | 6 +- pkg/tcpip/network/ipv4/igmp.go | 398 +++++++++++++++++++++++++++++ pkg/tcpip/network/ipv4/igmp_test.go | 491 ++++++++++++++++++++++++++++++++++++ pkg/tcpip/network/ipv4/ipv4.go | 60 ++++- pkg/tcpip/stack/stack.go | 1 - pkg/tcpip/tcpip.go | 57 +++++ pkg/tcpip/transport/tcp/endpoint.go | 2 +- 12 files changed, 1341 insertions(+), 14 deletions(-) create mode 100644 pkg/tcpip/header/igmp.go create mode 100644 pkg/tcpip/header/igmp_test.go create mode 100644 pkg/tcpip/network/ipv4/igmp.go create mode 100644 pkg/tcpip/network/ipv4/igmp_test.go diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 81f762e10..1c82c2c3b 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -1227,3 +1227,71 @@ func NDPRSOptions(opts []header.NDPOption) TransportChecker { ndpOptions(t, rs.Options(), opts) } } + +// IGMP checks the validity and properties of the given IGMP packet. It is +// expected to be used in conjunction with other IGMP transport checkers for +// specific properties. +func IGMP(checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + last := h[len(h)-1] + + if p := last.TransportProtocol(); p != header.IGMPProtocolNumber { + t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber) + } + + igmp := header.IGMP(last.Payload()) + for _, f := range checkers { + f(t, igmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// IGMPType creates a checker that checks the IGMP Type field. +func IGMPType(want header.IGMPType) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + igmp, ok := h.(header.IGMP) + if !ok { + t.Fatalf("got transport header = %T, want = header.IGMP", h) + } + if got := igmp.Type(); got != want { + t.Errorf("got igmp.Type() = %d, want = %d", got, want) + } + } +} + +// IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field. +func IGMPMaxRespTime(want byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + igmp, ok := h.(header.IGMP) + if !ok { + t.Fatalf("got transport header = %T, want = header.IGMP", h) + } + if got := igmp.MaxRespTime(); got != want { + t.Errorf("got igmp.MaxRespTime() = %d, want = %d", got, want) + } + } +} + +// IGMPGroupAddress creates a checker that checks the IGMP Group Address field. +func IGMPGroupAddress(want tcpip.Address) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + igmp, ok := h.(header.IGMP) + if !ok { + t.Fatalf("got transport header = %T, want = header.IGMP", h) + } + if got := igmp.GroupAddress(); got != want { + t.Errorf("got igmp.GroupAddress() = %s, want = %s", got, want) + } + } +} diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index ca6cbe41a..144093c3a 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -11,6 +11,7 @@ go_library( "gue.go", "icmpv4.go", "icmpv6.go", + "igmp.go", "interfaces.go", "ipv4.go", "ipv6.go", @@ -40,6 +41,7 @@ go_test( size = "small", srcs = [ "checksum_test.go", + "igmp_test.go", "ipv6_test.go", "ipversion_test.go", "tcp_test.go", diff --git a/pkg/tcpip/header/igmp.go b/pkg/tcpip/header/igmp.go new file mode 100644 index 000000000..e0f5d46f4 --- /dev/null +++ b/pkg/tcpip/header/igmp.go @@ -0,0 +1,166 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// IGMP represents an IGMP header stored in a byte array. +type IGMP []byte + +// IGMP implements `Transport`. +var _ Transport = (*IGMP)(nil) + +const ( + // IGMPMinimumSize is the minimum size of a valid IGMP packet in bytes, + // as per RFC 2236, Section 2, Page 2. + IGMPMinimumSize = 8 + + // IGMPQueryMinimumSize is the minimum size of a valid Membership Query + // Message in bytes, as per RFC 2236, Section 2, Page 2. + IGMPQueryMinimumSize = 8 + + // IGMPReportMinimumSize is the minimum size of a valid Report Message in + // bytes, as per RFC 2236, Section 2, Page 2. + IGMPReportMinimumSize = 8 + + // IGMPLeaveMessageMinimumSize is the minimum size of a valid Leave Message + // in bytes, as per RFC 2236, Section 2, Page 2. + IGMPLeaveMessageMinimumSize = 8 + + // IGMPTTL is the TTL for all IGMP messages, as per RFC 2236, Section 3, Page + // 3. + IGMPTTL = 1 + + // igmpTypeOffset defines the offset of the type field in an IGMP message. + igmpTypeOffset = 0 + + // igmpMaxRespTimeOffset defines the offset of the MaxRespTime field in an + // IGMP message. + igmpMaxRespTimeOffset = 1 + + // igmpChecksumOffset defines the offset of the checksum field in an IGMP + // message. + igmpChecksumOffset = 2 + + // igmpGroupAddressOffset defines the offset of the Group Address field in an + // IGMP message. + igmpGroupAddressOffset = 4 + + // IGMPProtocolNumber is IGMP's transport protocol number. + IGMPProtocolNumber tcpip.TransportProtocolNumber = 2 +) + +// IGMPType is the IGMP type field as per RFC 2236. +type IGMPType byte + +// Values for the IGMP Type described in RFC 2236 Section 2.1, Page 2. +// Descriptions below come from there. +const ( + // IGMPMembershipQuery indicates that the message type is Membership Query. + // "There are two sub-types of Membership Query messages: + // - General Query, used to learn which groups have members on an + // attached network. + // - Group-Specific Query, used to learn if a particular group + // has any members on an attached network. + // These two messages are differentiated by the Group Address, as + // described in section 1.4 ." + IGMPMembershipQuery IGMPType = 0x11 + // IGMPv1MembershipReport indicates that the message is a Membership Report + // generated by a host using the IGMPv1 protocol: "an additional type of + // message, for backwards-compatibility with IGMPv1" + IGMPv1MembershipReport IGMPType = 0x12 + // IGMPv2MembershipReport indicates that the Message type is a Membership + // Report generated by a host using the IGMPv2 protocol. + IGMPv2MembershipReport IGMPType = 0x16 + // IGMPLeaveGroup indicates that the message type is a Leave Group + // notification message. + IGMPLeaveGroup IGMPType = 0x17 +) + +// Type is the IGMP type field. +func (b IGMP) Type() IGMPType { return IGMPType(b[igmpTypeOffset]) } + +// SetType sets the IGMP type field. +func (b IGMP) SetType(t IGMPType) { b[igmpTypeOffset] = byte(t) } + +// MaxRespTime gets the MaxRespTimeField. This is meaningful only in Membership +// Query messages, in other cases it is set to 0 by the sender and ignored by +// the receiver. +func (b IGMP) MaxRespTime() byte { return b[igmpMaxRespTimeOffset] } + +// SetMaxRespTime sets the MaxRespTimeField. +func (b IGMP) SetMaxRespTime(m byte) { b[igmpMaxRespTimeOffset] = m } + +// Checksum is the IGMP checksum field. +func (b IGMP) Checksum() uint16 { + return binary.BigEndian.Uint16(b[igmpChecksumOffset:]) +} + +// SetChecksum sets the IGMP checksum field. +func (b IGMP) SetChecksum(checksum uint16) { + binary.BigEndian.PutUint16(b[igmpChecksumOffset:], checksum) +} + +// GroupAddress gets the Group Address field. +func (b IGMP) GroupAddress() tcpip.Address { + return tcpip.Address(b[igmpGroupAddressOffset:][:IPv4AddressSize]) +} + +// SetGroupAddress sets the Group Address field. +func (b IGMP) SetGroupAddress(address tcpip.Address) { + if n := copy(b[igmpGroupAddressOffset:], address); n != IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, IPv4AddressSize)) + } +} + +// SourcePort implements Transport.SourcePort. +func (IGMP) SourcePort() uint16 { + return 0 +} + +// DestinationPort implements Transport.DestinationPort. +func (IGMP) DestinationPort() uint16 { + return 0 +} + +// SetSourcePort implements Transport.SetSourcePort. +func (IGMP) SetSourcePort(uint16) { +} + +// SetDestinationPort implements Transport.SetDestinationPort. +func (IGMP) SetDestinationPort(uint16) { +} + +// Payload implements Transport.Payload. +func (IGMP) Payload() []byte { + return nil +} + +// IGMPCalculateChecksum calculates the IGMP checksum over the provided IGMP +// header. +func IGMPCalculateChecksum(h IGMP) uint16 { + // The header contains a checksum itself, set it aside to avoid checksumming + // the checksum and replace it afterwards. + existingXsum := h.Checksum() + h.SetChecksum(0) + xsum := ^Checksum(h, 0) + h.SetChecksum(existingXsum) + return xsum +} diff --git a/pkg/tcpip/header/igmp_test.go b/pkg/tcpip/header/igmp_test.go new file mode 100644 index 000000000..66e872880 --- /dev/null +++ b/pkg/tcpip/header/igmp_test.go @@ -0,0 +1,101 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// TestIGMPHeader tests the functions within header.igmp +func TestIGMPHeader(t *testing.T) { + b := []byte{ + 0x11, // IGMP Type, Membership Query + 0xF0, // Maximum Response Time + 0xC0, 0xC0, // Checksum + 0x01, 0x02, 0x03, 0x04, // Group Address + } + + igmpHeader := header.IGMP(b) + + if got, want := igmpHeader.Type(), header.IGMPMembershipQuery; got != want { + t.Errorf("got igmpHeader.Type() = %x, want = %x", got, want) + } + + if got, want := igmpHeader.MaxRespTime(), byte(0xF0); got != want { + t.Errorf("got igmpHeader.MaxRespTime() = %x, want = %x", got, want) + } + + if got, want := igmpHeader.Checksum(), uint16(0xC0C0); got != want { + t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, want) + } + + if got, want := igmpHeader.GroupAddress(), tcpip.Address("\x01\x02\x03\x04"); got != want { + t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, want) + } + + igmpType := header.IGMPv2MembershipReport + igmpHeader.SetType(igmpType) + if got := igmpHeader.Type(); got != igmpType { + t.Errorf("got igmpHeader.Type() = %x, want = %x", got, igmpType) + } + if got := header.IGMPType(b[0]); got != igmpType { + t.Errorf("got IGMPtype in backing buffer = %x, want %x", got, igmpType) + } + + respTime := byte(0x02) + igmpHeader.SetMaxRespTime(respTime) + if got := igmpHeader.MaxRespTime(); got != respTime { + t.Errorf("got igmpHeader.MaxRespTime() = %x, want = %x", got, respTime) + } + + checksum := uint16(0x0102) + igmpHeader.SetChecksum(checksum) + if got := igmpHeader.Checksum(); got != checksum { + t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, checksum) + } + + groupAddress := tcpip.Address("\x04\x03\x02\x01") + igmpHeader.SetGroupAddress(groupAddress) + if got := igmpHeader.GroupAddress(); got != groupAddress { + t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, groupAddress) + } +} + +// TestIGMPChecksum ensures that the checksum calculator produces the expected +// checksum. +func TestIGMPChecksum(t *testing.T) { + b := []byte{ + 0x11, // IGMP Type, Membership Query + 0xF0, // Maximum Response Time + 0xC0, 0xC0, // Checksum + 0x01, 0x02, 0x03, 0x04, // Group Address + } + + igmpHeader := header.IGMP(b) + + // Calculate the initial checksum after setting the checksum temporarily to 0 + // to avoid checksumming the checksum. + initialChecksum := igmpHeader.Checksum() + igmpHeader.SetChecksum(0) + checksum := ^header.Checksum(b, 0) + igmpHeader.SetChecksum(initialChecksum) + + if got := header.IGMPCalculateChecksum(igmpHeader); got != checksum { + t.Errorf("got IGMPCalculateChecksum = %x, want %x", got, checksum) + } +} diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 91fe7b6a5..5fddd2af6 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -157,6 +157,9 @@ const ( // IPv4Any is the non-routable IPv4 "any" meta address. IPv4Any tcpip.Address = "\x00\x00\x00\x00" + // IPv4AllRoutersGroup is a multicast address for all routers. + IPv4AllRoutersGroup tcpip.Address = "\xe0\x00\x00\x02" + // IPv4MinimumProcessableDatagramSize is the minimum size of an IP // packet that every IPv4 capable host must be able to // process/reassemble. diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 6252614ec..68b1ea1cd 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -6,6 +6,7 @@ go_library( name = "ipv4", srcs = [ "icmp.go", + "igmp.go", "ipv4.go", ], visibility = ["//visibility:public"], @@ -24,7 +25,10 @@ go_library( go_test( name = "ipv4_test", size = "small", - srcs = ["ipv4_test.go"], + srcs = [ + "igmp_test.go", + "ipv4_test.go", + ], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go new file mode 100644 index 000000000..e1de58f73 --- /dev/null +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -0,0 +1,398 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4 + +import ( + "fmt" + "sync" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + // igmpV1PresentDefault is the initial state for igmpV1Present in the + // igmpState. As per RFC 2236 Page 9 says "No IGMPv1 Router Present ... is + // the initial state." + igmpV1PresentDefault = false + + // v1RouterPresentTimeout from RFC 2236 Section 8.11, Page 18 + // See note on igmpState.igmpV1Present for more detail. + v1RouterPresentTimeout = 400 * time.Second + + // v1MaxRespTimeTenthSec from RFC 2236 Section 4, Page 5. "The IGMPv1 router + // will send General Queries with the Max Response Time set to 0. This MUST + // be interpreted as a value of 100 (10 seconds)." + v1MaxRespTimeTenthSec = 100 + + // UnsolicitedReportIntervalMaxTenthSec from RFC 2236 Section 8.10, Page 19. + // As all IGMP delay timers are set to a random value between 0 and the + // interval, this is technically a maximum. + UnsolicitedReportIntervalMaxTenthSec = 100 +) + +// igmpState is the per-interface IGMP state. +// +// igmpState.init() MUST be called after creating an IGMP state. +type igmpState struct { + // The IPv4 endpoint this igmpState is for. + ep *endpoint + + mu struct { + sync.RWMutex + + // memberships contains the map of host groups to their state, timer, and + // flag info. + memberships map[tcpip.Address]membershipInfo + + // igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from + // RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1 + // Membership Reports in response to its Queries, and will not pay + // attention to Version 2 Membership Reports. Therefore, a state variable + // MUST be kept for each interface, describing whether the multicast + // Querier on that interface is running IGMPv1 or IGMPv2. This variable + // MUST be based upon whether or not an IGMPv1 query was heard in the last + // [Version 1 Router Present Timeout] seconds" + igmpV1Present bool + + // igmpV1Job is scheduled when this interface receives an IGMPv1 style + // message, upon expiration the igmpV1Present flag is cleared. + // igmpV1Job may not be nil once igmpState is initialized. + igmpV1Job *tcpip.Job + } +} + +// membershipInfo holds the IGMPv2 state for a particular multicast address. +type membershipInfo struct { + // state contains the current IGMP state for this member. + state hostState + + // lastToSendReport is true if this was "the last host to send a report from + // this group." + // RFC 2236, Section 6, Page 9. This is used to track whether or not there + // are other hosts on this subnet that belong to this group - RFC 2236 + // Section 3, Page 5. + lastToSendReport bool + + // delayedReportJob is used to delay sending responses to IGMP messages in + // order to reduce duplicate reports from multiple hosts on the interface. + // Must not be nil. + delayedReportJob *tcpip.Job +} + +type hostState int + +// From RFC 2236, Section 6, Page 7. +const ( + // "'Non-Member' state, when the host does not belong to the group on + // the interface. This is the initial state for all memberships on + // all network interfaces; it requires no storage in the host." + _ hostState = iota + + // delayingMember is the "'Delaying Member' state, when the host belongs to + // the group on the interface and has a report delay timer running for that + // membership." + delayingMember + + // idleMember is the "Idle Member" state, when the host belongs to the group + // on the interface and does not have a report delay timer running for that + // membership. + idleMember +) + +// init sets up an igmpState struct, and is required to be called before using +// a new igmpState. +func (igmp *igmpState) init(ep *endpoint) { + igmp.mu.Lock() + defer igmp.mu.Unlock() + igmp.ep = ep + igmp.mu.memberships = make(map[tcpip.Address]membershipInfo) + igmp.mu.igmpV1Present = igmpV1PresentDefault + igmp.mu.igmpV1Job = igmp.ep.protocol.stack.NewJob(&igmp.mu, func() { + igmp.mu.igmpV1Present = false + }) +} + +func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { + stats := igmp.ep.protocol.stack.Stats() + received := stats.IGMP.PacketsReceived + headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize) + if !ok { + received.Invalid.Increment() + return + } + h := header.IGMP(headerView) + + // Temporarily reset the checksum field to 0 in order to calculate the proper + // checksum. + wantChecksum := h.Checksum() + h.SetChecksum(0) + gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) + h.SetChecksum(wantChecksum) + + if gotChecksum != wantChecksum { + received.ChecksumErrors.Increment() + return + } + + switch h.Type() { + case header.IGMPMembershipQuery: + received.MembershipQuery.Increment() + if len(headerView) < header.IGMPQueryMinimumSize { + received.Invalid.Increment() + return + } + igmp.handleMembershipQuery(h.GroupAddress(), h.MaxRespTime()) + case header.IGMPv1MembershipReport: + received.V1MembershipReport.Increment() + if len(headerView) < header.IGMPReportMinimumSize { + received.Invalid.Increment() + return + } + igmp.handleMembershipReport(h.GroupAddress()) + case header.IGMPv2MembershipReport: + received.V2MembershipReport.Increment() + if len(headerView) < header.IGMPReportMinimumSize { + received.Invalid.Increment() + return + } + igmp.handleMembershipReport(h.GroupAddress()) + case header.IGMPLeaveGroup: + received.LeaveGroup.Increment() + // As per RFC 2236 Section 6, Page 7: "IGMP messages other than Query or + // Report, are ignored in all states" + + default: + // As per RFC 2236 Section 2.1 Page 3: "Unrecognized message types should + // be silently ignored. New message types may be used by newer versions of + // IGMP, by multicast routing protocols, or other uses." + received.Unrecognized.Increment() + } +} + +func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime byte) { + igmp.mu.Lock() + defer igmp.mu.Unlock() + + // As per RFC 2236 Section 6, Page 10: If the maximum response time is zero + // then change the state to note that an IGMPv1 router is present and + // schedule the query received Job. + if maxRespTime == 0 { + igmp.mu.igmpV1Job.Cancel() + igmp.mu.igmpV1Job.Schedule(v1RouterPresentTimeout) + igmp.mu.igmpV1Present = true + maxRespTime = v1MaxRespTimeTenthSec + } + + // IPv4Any is the General Query Address. + if groupAddress == header.IPv4Any { + for membershipAddress, info := range igmp.mu.memberships { + igmp.setDelayTimerForAddressRLocked(membershipAddress, &info, maxRespTime) + igmp.mu.memberships[membershipAddress] = info + } + } else if info, ok := igmp.mu.memberships[groupAddress]; ok { + igmp.setDelayTimerForAddressRLocked(groupAddress, &info, maxRespTime) + igmp.mu.memberships[groupAddress] = info + } +} + +// setDelayTimerForAddressRLocked modifies the passed info only and does not +// modify IGMP state directly. +// +// Precondition: igmp.mu MUST be read locked. +func (igmp *igmpState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *membershipInfo, maxRespTime byte) { + if info.state == delayingMember { + // As per RFC 2236 Section 3, page 3: "If a timer for the group is already + // running, it is reset to the random value only if the requested Max + // Response Time is less than the remaining value of the running timer. + // TODO: Reset the timer if time remaining is greater than maxRespTime. + return + } + info.state = delayingMember + info.delayedReportJob.Cancel() + info.delayedReportJob.Schedule(igmp.calculateDelayTimerDuration(maxRespTime)) +} + +func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) { + igmp.mu.Lock() + defer igmp.mu.Unlock() + + // As per RFC 2236 Section 3, pages 3-4: "If the host receives another host's + // Report (version 1 or 2) while it has a timer running, it stops its timer + // for the specified group and does not send a Report" + if info, ok := igmp.mu.memberships[groupAddress]; ok { + info.delayedReportJob.Cancel() + info.lastToSendReport = false + igmp.mu.memberships[groupAddress] = info + } +} + +// writePacket assembles and sends an IGMP packet with the provided fields, +// incrementing the provided stat counter on success. +func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) { + igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize)) + igmpData.SetType(igmpType) + igmpData.SetGroupAddress(groupAddress) + igmpData.SetChecksum(header.IGMPCalculateChecksum(igmpData)) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(igmp.ep.MaxHeaderLength()), + Data: buffer.View(igmpData).ToVectorisedView(), + }) + + // TODO(gvisor.dev/issue/4888): We should not use the unspecified address, + // rather we should select an appropriate local address. + r := stack.Route{ + LocalAddress: header.IPv4Any, + RemoteAddress: destAddress, + } + igmp.ep.addIPHeader(&r, pkt, stack.NetworkHeaderParams{ + Protocol: header.IGMPProtocolNumber, + TTL: header.IGMPTTL, + TOS: stack.DefaultTOS, + }) + + // TODO(b/162198658): set the ROUTER_ALERT option when sending Host + // Membership Reports. + sent := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent + if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { + sent.Dropped.Increment() + } else { + switch igmpType { + case header.IGMPv1MembershipReport: + sent.V1MembershipReport.Increment() + case header.IGMPv2MembershipReport: + sent.V2MembershipReport.Increment() + case header.IGMPLeaveGroup: + sent.LeaveGroup.Increment() + default: + panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType)) + } + } +} + +// sendReport sends a Host Membership Report in response to a query or after +// this host joins a new group on this interface. +// +// Precondition: igmp.mu MUST be locked. +func (igmp *igmpState) sendReportLocked(groupAddress tcpip.Address) { + igmpType := header.IGMPv2MembershipReport + if igmp.mu.igmpV1Present { + igmpType = header.IGMPv1MembershipReport + } + igmp.writePacket(groupAddress, groupAddress, igmpType) + + // Update the state of the membership for this group. If the group no longer + // exists, do nothing since this report must have been a race with a remove + // or is in the process of being added. + info, ok := igmp.mu.memberships[groupAddress] + if !ok { + return + } + info.state = idleMember + info.lastToSendReport = true + igmp.mu.memberships[groupAddress] = info +} + +// sendLeave sends a Leave Group report to the IPv4 All Routers Group. +// +// Precondition: igmp.mu MUST be read locked. +func (igmp *igmpState) sendLeaveRLocked(groupAddress tcpip.Address) { + // As per RFC 2236 Section 6, Page 8: "If the interface state says the + // Querier is running IGMPv1, this action SHOULD be skipped. If the flag + // saying we were the last host to report is cleared, this action MAY be + // skipped." + if igmp.mu.igmpV1Present || !igmp.mu.memberships[groupAddress].lastToSendReport { + return + } + + igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup) +} + +// joinGroup handles adding a new group to the membership map, setting up the +// IGMP state for the group, and sending and scheduling the required +// messages. +// +// If the group already exists in the membership map, returns +// tcpip.ErrDuplicateAddress. +func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) *tcpip.Error { + igmp.mu.Lock() + defer igmp.mu.Unlock() + if _, ok := igmp.mu.memberships[groupAddress]; ok { + // The group already exists in the membership map. + return tcpip.ErrDuplicateAddress + } + + info := membershipInfo{ + // There isn't a Job scheduled currently, so it's just idle. + state: idleMember, + // Joining a group immediately sends a report. + lastToSendReport: true, + delayedReportJob: igmp.ep.protocol.stack.NewJob(&igmp.mu, func() { + igmp.sendReportLocked(groupAddress) + }), + } + + // As per RFC 2236 Section 3, Page 5: "When a host joins a multicast group, + // it should immediately transmit an unsolicited Version 2 Membership Report + // for that group" ... "it is recommended that it be repeated" + igmp.sendReportLocked(groupAddress) + igmp.setDelayTimerForAddressRLocked(groupAddress, &info, UnsolicitedReportIntervalMaxTenthSec) + igmp.mu.memberships[groupAddress] = info + + return nil +} + +// leaveGroup handles removing the group from the membership map, cancels any +// delay timers associated with that group, and sends the Leave Group message +// if required. +// +// If the group does not exist in the membership map, this function will +// silently return. +func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) { + igmp.mu.Lock() + defer igmp.mu.Unlock() + info, ok := igmp.mu.memberships[groupAddress] + if !ok { + return + } + + // Clean up the state of the group before sending the leave message and + // removing it from the map. + info.delayedReportJob.Cancel() + info.state = idleMember + igmp.mu.memberships[groupAddress] = info + + igmp.sendLeaveRLocked(groupAddress) + delete(igmp.mu.memberships, groupAddress) +} + +// RFC 2236 Section 3, Page 3: The response time is set to a "random value... +// selected from the range (0, Max Response Time]" where Max Resp Time is given +// in units of 1/10 of a second. +func (igmp *igmpState) calculateDelayTimerDuration(maxRespTime byte) time.Duration { + maxRespTimeDuration := DecisecondToSecond(maxRespTime) + return time.Duration(igmp.ep.protocol.stack.Rand().Int63n(int64(maxRespTimeDuration))) +} + +// DecisecondToSecond converts a byte representing deci-seconds to a Duration +// type. This helper function exists because the IGMP stack sends and receives +// Max Response Times in deci-seconds. +func DecisecondToSecond(ds byte) time.Duration { + return time.Duration(ds) * time.Second / 10 +} diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go new file mode 100644 index 000000000..a0f37885a --- /dev/null +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -0,0 +1,491 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + // endpointAddr = tcpip.Address("\x0a\x00\x00\x02") + multicastAddr = tcpip.Address("\xe0\x00\x00\x03") + nicID = 1 +) + +var ( + // unsolicitedReportIntervalMax is the maximum amount of time the NIC will + // wait before sending an unsolicited report after joining a multicast group. + unsolicitedReportIntervalMax = ipv4.DecisecondToSecond(ipv4.UnsolicitedReportIntervalMaxTenthSec) +) + +// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet +// sent to the provided address with the passed fields set. Raises a t.Error if +// any field does not match. +func validateIgmpPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) { + t.Helper() + + payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) + checker.IPv4(t, payload, + checker.DstAddr(remoteAddress), + checker.IGMP( + checker.IGMPType(igmpType), + checker.IGMPMaxRespTime(maxRespTime), + checker.IGMPGroupAddress(groupAddress), + ), + ) +} + +func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { + t.Helper() + + // Create an endpoint of queue size 1, since no more than 1 packets are ever + // queued in the tests in this file. + e := channel.New(1, 1280, linkAddr) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{ + IGMPEnabled: igmpEnabled, + })}, + Clock: clock, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + return e, s, clock +} + +func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) { + buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize) + + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(len(buf)), + TTL: 1, + Protocol: uint8(header.IGMPProtocolNumber), + SrcAddr: header.IPv4Any, + DstAddr: header.IPv4AllSystems, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + igmp := header.IGMP(buf[header.IPv4MinimumSize:]) + igmp.SetType(igmpType) + igmp.SetMaxRespTime(maxRespTime) + igmp.SetGroupAddress(groupAddress) + igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) + + e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) +} + +// TestIgmpDisabled tests that IGMP is not enabled with a default +// stack.Options. This also tests that this NIC does not send the IGMP Join +// Group for the All Hosts group it automatically joins when created. +func TestIgmpDisabled(t *testing.T) { + e, s, _ := createStack(t, false) + + // This NIC will join the All Hosts group when created. Verify that does not + // send a report. + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 0 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 0", got) + } + p, ok := e.Read() + if ok { + t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt) + } + + // Test joining a specific group explicitly and verify that no reports are + // sent. + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(ipv4.ProtocolNumber, %d, %s) = %s", nicID, multicastAddr, err) + } + + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 0 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 0", got) + } + p, ok = e.Read() + if ok { + t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt) + } + + // Inject a General Membership Query, which is an IGMP Membership Query with + // a zeroed Group Address (IPv4Any) to verify that it does not reach the + // handler. + createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, ipv4.UnsolicitedReportIntervalMaxTenthSec, header.IPv4Any) + + if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 0 { + t.Fatalf("got Membership Queries received = %d, want = 0", got) + } + p, ok = e.Read() + if ok { + t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt) + } +} + +// TestIgmpReceivesIGMPMessages tests that the IGMP stack increments packet +// counters when it receives properly formatted Membership Queries, Membership +// Reports, and LeaveGroup Messages sent to this address. Note: test includes +// IGMP header fields that are not explicitly tested in order to inject proper +// IGMP packets. +func TestIgmpReceivesIGMPMessages(t *testing.T) { + tests := []struct { + name string + headerType header.IGMPType + maxRespTime byte + groupAddress tcpip.Address + statCounter func(tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter + }{ + { + name: "General Membership Query", + headerType: header.IGMPMembershipQuery, + maxRespTime: ipv4.UnsolicitedReportIntervalMaxTenthSec, + groupAddress: header.IPv4Any, + statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter { + return stats.MembershipQuery + }, + }, + { + name: "IGMPv1 Membership Report", + headerType: header.IGMPv1MembershipReport, + maxRespTime: 0, + groupAddress: header.IPv4AllSystems, + statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter { + return stats.V1MembershipReport + }, + }, + { + name: "IGMPv2 Membership Report", + headerType: header.IGMPv2MembershipReport, + maxRespTime: 0, + groupAddress: header.IPv4AllSystems, + statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter { + return stats.V2MembershipReport + }, + }, + { + name: "Leave Group", + headerType: header.IGMPLeaveGroup, + maxRespTime: 0, + groupAddress: header.IPv4AllRoutersGroup, + statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter { + return stats.LeaveGroup + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, _ := createStack(t, true) + + createAndInjectIGMPPacket(e, test.headerType, test.maxRespTime, test.groupAddress) + + if got := test.statCounter(s.Stats().IGMP.PacketsReceived).Value(); got != 1 { + t.Fatalf("got %s received = %d, want = 1", test.name, got) + } + }) + } +} + +// TestIgmpJoinGroup tests that when explicitly joining a multicast group, the +// IGMP stack schedules and sends correct Membership Reports. +func TestIgmpJoinGroup(t *testing.T) { + e, s, clock := createStack(t, true) + + // Test joining a specific address explicitly and verify a Membership Report + // is sent immediately. + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } + + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + if t.Failed() { + t.FailNow() + } + + // Verify the second Membership Report is sent after a random interval up to + // the unsolicitedReportIntervalMax. + p, ok = e.Read() + if ok { + t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt) + } + clock.Advance(unsolicitedReportIntervalMax) + p, ok = e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) +} + +// TestIgmpLeaveGroup tests that when leaving a previously joined multicast +// group the IGMP enabled NIC sends the appropriate message. +func TestIgmpLeaveGroup(t *testing.T) { + e, s, clock := createStack(t, true) + + // Join a group so that it can be left, validate the immediate Membership + // Report is sent only to the multicast address joined. + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + if t.Failed() { + t.FailNow() + } + + // Verify the second Membership Report is sent after a random interval up to + // the unsolicitedReportIntervalMax, and is sent to the multicast address + // being joined. + p, ok = e.Read() + if ok { + t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt) + } + clock.Advance(unsolicitedReportIntervalMax) + p, ok = e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + if t.Failed() { + t.FailNow() + } + + // Now that there are no packets queued and none scheduled to be sent, leave + // the group. + if err := s.LeaveGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("LeaveGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + // Observe the Leave Group Message to verify that the Leave Group message is + // sent to the All Routers group but that the message itself has the + // multicast address being left. + p, ok = e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected LeaveGroup") + } + if got := s.Stats().IGMP.PacketsSent.LeaveGroup.Value(); got != 1 { + t.Fatalf("got LeaveGroup messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, header.IPv4AllRoutersGroup, header.IGMPLeaveGroup, 0, multicastAddr) +} + +// TestIgmpJoinLeaveGroup tests that when leaving a previously joined multicast +// group before the Unsolicited Report Interval cancels the second membership +// report. +func TestIgmpJoinLeaveGroup(t *testing.T) { + _, s, clock := createStack(t, true) + + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + // Verify that this NIC sent a Membership Report for only the group just + // joined. + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } + + if err := s.LeaveGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("LeaveGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + // Wait for the standard IGMP Unsolicited Report Interval duration before + // verifying that the unsolicited Membership Report was sent after leaving + // the group. + clock.Advance(unsolicitedReportIntervalMax) + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } +} + +// TestIgmpMembershipQueryReport tests the handling of both incoming IGMP +// Membership Queries and outgoing Membership Reports. +func TestIgmpMembershipQueryReport(t *testing.T) { + e, s, clock := createStack(t, true) + + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + if t.Failed() { + t.FailNow() + } + + p, ok = e.Read() + if ok { + t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt) + } + clock.Advance(unsolicitedReportIntervalMax) + p, ok = e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + + // Inject a General Membership Query, which is an IGMP Membership Query with + // a zeroed Group Address (IPv4Any) with the shortened Max Response Time. + const maxRespTimeDS = 10 + createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, maxRespTimeDS, header.IPv4Any) + + p, ok = e.Read() + if ok { + t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt) + } + clock.Advance(ipv4.DecisecondToSecond(maxRespTimeDS)) + p, ok = e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 3 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 3", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) +} + +// TestIgmpMultipleHosts tests the handling of IGMP Leave when we are not the +// most recent IGMP host to join a multicast network. +func TestIgmpMultipleHosts(t *testing.T) { + e, s, clock := createStack(t, true) + + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + if t.Failed() { + t.FailNow() + } + + // Inject another Host's Join Group message so that this host is not the + // latest to send the report. Set Max Response Time to 0 for Membership + // Reports. + createAndInjectIGMPPacket(e, header.IGMPv2MembershipReport, 0, multicastAddr) + + if err := s.LeaveGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("LeaveGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + // Wait to be sure that no Leave Group messages were sent up to the max + // unsolicited report interval since it was not the last host to join this + // group. + clock.Advance(unsolicitedReportIntervalMax) + if got := s.Stats().IGMP.PacketsSent.LeaveGroup.Value(); got != 0 { + t.Fatalf("got LeaveGroup messages sent = %d, want = 0", got) + } +} + +// TestIgmpV1Present tests the handling of the case where an IGMPv1 router is +// present on the network. The IGMP stack will then send IGMPv1 Membership +// reports for backwards compatibility. +func TestIgmpV1Present(t *testing.T) { + e, s, clock := createStack(t, true) + + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + // This NIC will send an IGMPv2 report immediately, before this test can get + // the IGMPv1 General Membership Query in. + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + if t.Failed() { + t.FailNow() + } + + // Inject an IGMPv1 General Membership Query which is identical to a standard + // membership query except the Max Response Time is set to 0, which will tell + // the stack that this is a router using IGMPv1. Send it to the all systems + // group which is the only group this host belongs to. + createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, 0, header.IPv4AllSystems) + if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 1 { + t.Fatalf("got Membership Queries received = %d, want = 1", got) + } + + // Before advancing the clock, verify that this host has not sent a + // V1MembershipReport yet. + if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 0 { + t.Fatalf("got V1MembershipReport messages sent = %d, want = 0", got) + } + + // Verify the solicited Membership Report is sent. Now that this NIC has seen + // an IGMPv1 query, it should send an IGMPv1 Membership Report. + p, ok = e.Read() + if ok { + t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt) + } + clock.Advance(unsolicitedReportIntervalMax) + p, ok = e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V1MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 { + t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr) +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index ea8505692..7c759be9a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -72,6 +72,7 @@ type endpoint struct { nic stack.NetworkInterface dispatcher stack.TransportDispatcher protocol *protocol + igmp igmpState // enabled is set to 1 when the enpoint is enabled and 0 when it is // disabled. @@ -94,6 +95,7 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa protocol: p, } e.mu.addressableEndpointState.Init(e) + e.igmp.init(e) return e } @@ -703,6 +705,13 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { e.handleICMP(pkt) return } + if p == header.IGMPProtocolNumber { + if e.protocol.options.IGMPEnabled { + e.igmp.handleIGMP(pkt) + } + // Nothing further to do with an IGMP packet, even if IGMP is not enabled. + return + } if opts := h.Options(); len(opts) != 0 { // TODO(gvisor.dev/issue/4586): // When we add forwarding support we should use the verified options @@ -834,14 +843,26 @@ func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.JoinGroup(addr) + + joinedGroup, err := e.mu.addressableEndpointState.JoinGroup(addr) + if err == nil && joinedGroup && e.protocol.options.IGMPEnabled { + _ = e.igmp.joinGroup(addr) + } + + return joinedGroup, err } // LeaveGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.LeaveGroup(addr) + + leftGroup, err := e.mu.addressableEndpointState.LeaveGroup(addr) + if err == nil && leftGroup && e.protocol.options.IGMPEnabled { + e.igmp.leaveGroup(addr) + } + + return leftGroup, err } // IsInGroup implements stack.GroupAddressableEndpoint. @@ -874,6 +895,8 @@ type protocol struct { hashIV uint32 fragmentation *fragmentation.Fragmentation + + options Options } // Number returns the ipv4 protocol number. @@ -1007,8 +1030,15 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV ui return hash.Hash3Words(a, b, uint32(protocol), hashIV) } -// NewProtocol returns an IPv4 network protocol. -func NewProtocol(s *stack.Stack) stack.NetworkProtocol { +// Options holds options to configure a new protocol. +type Options struct { + // IGMPEnabled indicates whether incoming IGMP packets will be handled and if + // this endpoint will transmit IGMP packets on IGMP related events. + IGMPEnabled bool +} + +// NewProtocolWithOptions returns an IPv4 network protocol. +func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { ids := make([]uint32, buckets) // Randomly initialize hashIV and the ids. @@ -1018,14 +1048,22 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { } hashIV := r[buckets] - p := &protocol{ - stack: s, - ids: ids, - hashIV: hashIV, - defaultTTL: DefaultTTL, + return func(s *stack.Stack) stack.NetworkProtocol { + p := &protocol{ + stack: s, + ids: ids, + hashIV: hashIV, + defaultTTL: DefaultTTL, + options: opts, + } + p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) + return p } - p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) - return p +} + +// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options. +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { + return NewProtocolWithOptions(Options{})(s) } func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index c5d45ac6a..a2d234e7d 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1893,7 +1893,6 @@ func (s *Stack) RemoveTCPProbe() { // JoinGroup joins the given multicast group on the given NIC. func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { - // TODO: notify network of subscription via igmp protocol. s.mu.RLock() defer s.mu.RUnlock() diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index ad7a5b22f..3c7c5c0a8 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1434,6 +1434,60 @@ type ICMPStats struct { V6PacketsReceived ICMPv6ReceivedPacketStats } +// IGMPPacketStats enumerates counts for all IGMP packet types. +type IGMPPacketStats struct { + // MembershipQuery is the total number of Membership Query messages counted. + MembershipQuery *StatCounter + + // V1MembershipReport is the total number of Version 1 Membership Report + // messages counted. + V1MembershipReport *StatCounter + + // V2MembershipReport is the total number of Version 2 Membership Report + // messages counted. + V2MembershipReport *StatCounter + + // LeaveGroup is the total number of Leave Group messages counted. + LeaveGroup *StatCounter +} + +// IGMPSentPacketStats collects outbound IGMP-specific stats. +type IGMPSentPacketStats struct { + IGMPPacketStats + + // Dropped is the total number of IGMP packets dropped due to link layer + // errors. + Dropped *StatCounter +} + +// IGMPReceivedPacketStats collects inbound IGMP-specific stats. +type IGMPReceivedPacketStats struct { + IGMPPacketStats + + // Invalid is the total number of IGMP packets received that IGMP could not + // parse. + Invalid *StatCounter + + // ChecksumErrors is the total number of IGMP packets dropped due to bad + // checksums. + ChecksumErrors *StatCounter + + // Unrecognized is the total number of unrecognized messages counted, these + // are silently ignored for forward-compatibilty. + Unrecognized *StatCounter +} + +// IGMPStats colelcts IGMP-specific stats. +type IGMPStats struct { + // IGMPSentPacketStats contains counts of sent packets by IGMP packet type + // and a single count of invalid packets received. + PacketsSent IGMPSentPacketStats + + // IGMPReceivedPacketStats contains counts of received packets by IGMP packet + // type and a single count of invalid packets received. + PacketsReceived IGMPReceivedPacketStats +} + // IPStats collects IP-specific stats (both v4 and v6). type IPStats struct { // PacketsReceived is the total number of IP packets received from the @@ -1639,6 +1693,9 @@ type Stats struct { // ICMP breaks out ICMP-specific stats (both v4 and v6). ICMP ICMPStats + // IGMP breaks out IGMP-specific stats. + IGMP IGMPStats + // IP breaks out IP-specific stats (both v4 and v6). IP IPStats diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 801cf8e0e..64563a8ba 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2797,7 +2797,7 @@ func (e *endpoint) getRemoteAddress() tcpip.FullAddress { func (*endpoint) HandlePacket(stack.TransportEndpointID, *stack.PacketBuffer) { // TCP HandlePacket is not required anymore as inbound packets first - // land at the Dispatcher which then can either delivery using the + // land at the Dispatcher which then can either deliver using the // worker go routine or directly do the invoke the tcp processing inline // based on the state of the endpoint. } -- cgit v1.2.3