diff options
-rw-r--r-- | pkg/tcpip/checker/checker.go | 68 | ||||
-rw-r--r-- | pkg/tcpip/header/BUILD | 2 | ||||
-rw-r--r-- | pkg/tcpip/header/igmp.go | 166 | ||||
-rw-r--r-- | pkg/tcpip/header/igmp_test.go | 101 | ||||
-rw-r--r-- | pkg/tcpip/header/ipv4.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/BUILD | 6 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/igmp.go | 398 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/igmp_test.go | 491 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 60 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 1 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 57 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 2 |
12 files changed, 1341 insertions, 14 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 81f762e10..1c82c2c3b 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -1227,3 +1227,71 @@ func NDPRSOptions(opts []header.NDPOption) TransportChecker { ndpOptions(t, rs.Options(), opts) } } + +// IGMP checks the validity and properties of the given IGMP packet. It is +// expected to be used in conjunction with other IGMP transport checkers for +// specific properties. +func IGMP(checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + last := h[len(h)-1] + + if p := last.TransportProtocol(); p != header.IGMPProtocolNumber { + t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber) + } + + igmp := header.IGMP(last.Payload()) + for _, f := range checkers { + f(t, igmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// IGMPType creates a checker that checks the IGMP Type field. +func IGMPType(want header.IGMPType) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + igmp, ok := h.(header.IGMP) + if !ok { + t.Fatalf("got transport header = %T, want = header.IGMP", h) + } + if got := igmp.Type(); got != want { + t.Errorf("got igmp.Type() = %d, want = %d", got, want) + } + } +} + +// IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field. +func IGMPMaxRespTime(want byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + igmp, ok := h.(header.IGMP) + if !ok { + t.Fatalf("got transport header = %T, want = header.IGMP", h) + } + if got := igmp.MaxRespTime(); got != want { + t.Errorf("got igmp.MaxRespTime() = %d, want = %d", got, want) + } + } +} + +// IGMPGroupAddress creates a checker that checks the IGMP Group Address field. 
+func IGMPGroupAddress(want tcpip.Address) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + igmp, ok := h.(header.IGMP) + if !ok { + t.Fatalf("got transport header = %T, want = header.IGMP", h) + } + if got := igmp.GroupAddress(); got != want { + t.Errorf("got igmp.GroupAddress() = %s, want = %s", got, want) + } + } +} diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index ca6cbe41a..144093c3a 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -11,6 +11,7 @@ go_library( "gue.go", "icmpv4.go", "icmpv6.go", + "igmp.go", "interfaces.go", "ipv4.go", "ipv6.go", @@ -40,6 +41,7 @@ go_test( size = "small", srcs = [ "checksum_test.go", + "igmp_test.go", "ipv6_test.go", "ipversion_test.go", "tcp_test.go", diff --git a/pkg/tcpip/header/igmp.go b/pkg/tcpip/header/igmp.go new file mode 100644 index 000000000..e0f5d46f4 --- /dev/null +++ b/pkg/tcpip/header/igmp.go @@ -0,0 +1,166 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// IGMP represents an IGMP header stored in a byte array. +type IGMP []byte + +// IGMP implements `Transport`. +var _ Transport = (*IGMP)(nil) + +const ( + // IGMPMinimumSize is the minimum size of a valid IGMP packet in bytes, + // as per RFC 2236, Section 2, Page 2. 
+ IGMPMinimumSize = 8 + + // IGMPQueryMinimumSize is the minimum size of a valid Membership Query + // Message in bytes, as per RFC 2236, Section 2, Page 2. + IGMPQueryMinimumSize = 8 + + // IGMPReportMinimumSize is the minimum size of a valid Report Message in + // bytes, as per RFC 2236, Section 2, Page 2. + IGMPReportMinimumSize = 8 + + // IGMPLeaveMessageMinimumSize is the minimum size of a valid Leave Message + // in bytes, as per RFC 2236, Section 2, Page 2. + IGMPLeaveMessageMinimumSize = 8 + + // IGMPTTL is the TTL for all IGMP messages, as per RFC 2236, Section 3, Page + // 3. + IGMPTTL = 1 + + // igmpTypeOffset defines the offset of the type field in an IGMP message. + igmpTypeOffset = 0 + + // igmpMaxRespTimeOffset defines the offset of the MaxRespTime field in an + // IGMP message. + igmpMaxRespTimeOffset = 1 + + // igmpChecksumOffset defines the offset of the checksum field in an IGMP + // message. + igmpChecksumOffset = 2 + + // igmpGroupAddressOffset defines the offset of the Group Address field in an + // IGMP message. + igmpGroupAddressOffset = 4 + + // IGMPProtocolNumber is IGMP's transport protocol number. + IGMPProtocolNumber tcpip.TransportProtocolNumber = 2 +) + +// IGMPType is the IGMP type field as per RFC 2236. +type IGMPType byte + +// Values for the IGMP Type described in RFC 2236 Section 2.1, Page 2. +// Descriptions below come from there. +const ( + // IGMPMembershipQuery indicates that the message type is Membership Query. + // "There are two sub-types of Membership Query messages: + // - General Query, used to learn which groups have members on an + // attached network. + // - Group-Specific Query, used to learn if a particular group + // has any members on an attached network. + // These two messages are differentiated by the Group Address, as + // described in section 1.4 ." 
+ IGMPMembershipQuery IGMPType = 0x11 + // IGMPv1MembershipReport indicates that the message is a Membership Report + // generated by a host using the IGMPv1 protocol: "an additional type of + // message, for backwards-compatibility with IGMPv1" + IGMPv1MembershipReport IGMPType = 0x12 + // IGMPv2MembershipReport indicates that the Message type is a Membership + // Report generated by a host using the IGMPv2 protocol. + IGMPv2MembershipReport IGMPType = 0x16 + // IGMPLeaveGroup indicates that the message type is a Leave Group + // notification message. + IGMPLeaveGroup IGMPType = 0x17 +) + +// Type is the IGMP type field. +func (b IGMP) Type() IGMPType { return IGMPType(b[igmpTypeOffset]) } + +// SetType sets the IGMP type field. +func (b IGMP) SetType(t IGMPType) { b[igmpTypeOffset] = byte(t) } + +// MaxRespTime gets the MaxRespTimeField. This is meaningful only in Membership +// Query messages, in other cases it is set to 0 by the sender and ignored by +// the receiver. +func (b IGMP) MaxRespTime() byte { return b[igmpMaxRespTimeOffset] } + +// SetMaxRespTime sets the MaxRespTimeField. +func (b IGMP) SetMaxRespTime(m byte) { b[igmpMaxRespTimeOffset] = m } + +// Checksum is the IGMP checksum field. +func (b IGMP) Checksum() uint16 { + return binary.BigEndian.Uint16(b[igmpChecksumOffset:]) +} + +// SetChecksum sets the IGMP checksum field. +func (b IGMP) SetChecksum(checksum uint16) { + binary.BigEndian.PutUint16(b[igmpChecksumOffset:], checksum) +} + +// GroupAddress gets the Group Address field. +func (b IGMP) GroupAddress() tcpip.Address { + return tcpip.Address(b[igmpGroupAddressOffset:][:IPv4AddressSize]) +} + +// SetGroupAddress sets the Group Address field. +func (b IGMP) SetGroupAddress(address tcpip.Address) { + if n := copy(b[igmpGroupAddressOffset:], address); n != IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, IPv4AddressSize)) + } +} + +// SourcePort implements Transport.SourcePort. 
+func (IGMP) SourcePort() uint16 { + return 0 +} + +// DestinationPort implements Transport.DestinationPort. +func (IGMP) DestinationPort() uint16 { + return 0 +} + +// SetSourcePort implements Transport.SetSourcePort. +func (IGMP) SetSourcePort(uint16) { +} + +// SetDestinationPort implements Transport.SetDestinationPort. +func (IGMP) SetDestinationPort(uint16) { +} + +// Payload implements Transport.Payload. +func (IGMP) Payload() []byte { + return nil +} + +// IGMPCalculateChecksum calculates the IGMP checksum over the provided IGMP +// header. +func IGMPCalculateChecksum(h IGMP) uint16 { + // The header contains a checksum itself, set it aside to avoid checksumming + // the checksum and replace it afterwards. + existingXsum := h.Checksum() + h.SetChecksum(0) + xsum := ^Checksum(h, 0) + h.SetChecksum(existingXsum) + return xsum +} diff --git a/pkg/tcpip/header/igmp_test.go b/pkg/tcpip/header/igmp_test.go new file mode 100644 index 000000000..66e872880 --- /dev/null +++ b/pkg/tcpip/header/igmp_test.go @@ -0,0 +1,101 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package header_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// TestIGMPHeader tests the functions within header.igmp +func TestIGMPHeader(t *testing.T) { + b := []byte{ + 0x11, // IGMP Type, Membership Query + 0xF0, // Maximum Response Time + 0xC0, 0xC0, // Checksum + 0x01, 0x02, 0x03, 0x04, // Group Address + } + + igmpHeader := header.IGMP(b) + + if got, want := igmpHeader.Type(), header.IGMPMembershipQuery; got != want { + t.Errorf("got igmpHeader.Type() = %x, want = %x", got, want) + } + + if got, want := igmpHeader.MaxRespTime(), byte(0xF0); got != want { + t.Errorf("got igmpHeader.MaxRespTime() = %x, want = %x", got, want) + } + + if got, want := igmpHeader.Checksum(), uint16(0xC0C0); got != want { + t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, want) + } + + if got, want := igmpHeader.GroupAddress(), tcpip.Address("\x01\x02\x03\x04"); got != want { + t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, want) + } + + igmpType := header.IGMPv2MembershipReport + igmpHeader.SetType(igmpType) + if got := igmpHeader.Type(); got != igmpType { + t.Errorf("got igmpHeader.Type() = %x, want = %x", got, igmpType) + } + if got := header.IGMPType(b[0]); got != igmpType { + t.Errorf("got IGMPtype in backing buffer = %x, want %x", got, igmpType) + } + + respTime := byte(0x02) + igmpHeader.SetMaxRespTime(respTime) + if got := igmpHeader.MaxRespTime(); got != respTime { + t.Errorf("got igmpHeader.MaxRespTime() = %x, want = %x", got, respTime) + } + + checksum := uint16(0x0102) + igmpHeader.SetChecksum(checksum) + if got := igmpHeader.Checksum(); got != checksum { + t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, checksum) + } + + groupAddress := tcpip.Address("\x04\x03\x02\x01") + igmpHeader.SetGroupAddress(groupAddress) + if got := igmpHeader.GroupAddress(); got != groupAddress { + t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, groupAddress) + } +} + +// 
TestIGMPChecksum ensures that the checksum calculator produces the expected +// checksum. +func TestIGMPChecksum(t *testing.T) { + b := []byte{ + 0x11, // IGMP Type, Membership Query + 0xF0, // Maximum Response Time + 0xC0, 0xC0, // Checksum + 0x01, 0x02, 0x03, 0x04, // Group Address + } + + igmpHeader := header.IGMP(b) + + // Calculate the initial checksum after setting the checksum temporarily to 0 + // to avoid checksumming the checksum. + initialChecksum := igmpHeader.Checksum() + igmpHeader.SetChecksum(0) + checksum := ^header.Checksum(b, 0) + igmpHeader.SetChecksum(initialChecksum) + + if got := header.IGMPCalculateChecksum(igmpHeader); got != checksum { + t.Errorf("got IGMPCalculateChecksum = %x, want %x", got, checksum) + } +} diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 91fe7b6a5..5fddd2af6 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -157,6 +157,9 @@ const ( // IPv4Any is the non-routable IPv4 "any" meta address. IPv4Any tcpip.Address = "\x00\x00\x00\x00" + // IPv4AllRoutersGroup is a multicast address for all routers. + IPv4AllRoutersGroup tcpip.Address = "\xe0\x00\x00\x02" + // IPv4MinimumProcessableDatagramSize is the minimum size of an IP // packet that every IPv4 capable host must be able to // process/reassemble. 
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 6252614ec..68b1ea1cd 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -6,6 +6,7 @@ go_library( name = "ipv4", srcs = [ "icmp.go", + "igmp.go", "ipv4.go", ], visibility = ["//visibility:public"], @@ -24,7 +25,10 @@ go_library( go_test( name = "ipv4_test", size = "small", - srcs = ["ipv4_test.go"], + srcs = [ + "igmp_test.go", + "ipv4_test.go", + ], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go new file mode 100644 index 000000000..e1de58f73 --- /dev/null +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -0,0 +1,398 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4 + +import ( + "fmt" + "sync" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + // igmpV1PresentDefault is the initial state for igmpV1Present in the + // igmpState. As per RFC 2236 Page 9 says "No IGMPv1 Router Present ... is + // the initial state." + igmpV1PresentDefault = false + + // v1RouterPresentTimeout from RFC 2236 Section 8.11, Page 18 + // See note on igmpState.igmpV1Present for more detail. + v1RouterPresentTimeout = 400 * time.Second + + // v1MaxRespTimeTenthSec from RFC 2236 Section 4, Page 5. 
"The IGMPv1 router + // will send General Queries with the Max Response Time set to 0. This MUST + // be interpreted as a value of 100 (10 seconds)." + v1MaxRespTimeTenthSec = 100 + + // UnsolicitedReportIntervalMaxTenthSec from RFC 2236 Section 8.10, Page 19. + // As all IGMP delay timers are set to a random value between 0 and the + // interval, this is technically a maximum. + UnsolicitedReportIntervalMaxTenthSec = 100 +) + +// igmpState is the per-interface IGMP state. +// +// igmpState.init() MUST be called after creating an IGMP state. +type igmpState struct { + // The IPv4 endpoint this igmpState is for. + ep *endpoint + + mu struct { + sync.RWMutex + + // memberships contains the map of host groups to their state, timer, and + // flag info. + memberships map[tcpip.Address]membershipInfo + + // igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from + // RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1 + // Membership Reports in response to its Queries, and will not pay + // attention to Version 2 Membership Reports. Therefore, a state variable + // MUST be kept for each interface, describing whether the multicast + // Querier on that interface is running IGMPv1 or IGMPv2. This variable + // MUST be based upon whether or not an IGMPv1 query was heard in the last + // [Version 1 Router Present Timeout] seconds" + igmpV1Present bool + + // igmpV1Job is scheduled when this interface receives an IGMPv1 style + // message, upon expiration the igmpV1Present flag is cleared. + // igmpV1Job may not be nil once igmpState is initialized. + igmpV1Job *tcpip.Job + } +} + +// membershipInfo holds the IGMPv2 state for a particular multicast address. +type membershipInfo struct { + // state contains the current IGMP state for this member. + state hostState + + // lastToSendReport is true if this was "the last host to send a report from + // this group." + // RFC 2236, Section 6, Page 9. 
This is used to track whether or not there + // are other hosts on this subnet that belong to this group - RFC 2236 + // Section 3, Page 5. + lastToSendReport bool + + // delayedReportJob is used to delay sending responses to IGMP messages in + // order to reduce duplicate reports from multiple hosts on the interface. + // Must not be nil. + delayedReportJob *tcpip.Job +} + +type hostState int + +// From RFC 2236, Section 6, Page 7. +const ( + // "'Non-Member' state, when the host does not belong to the group on + // the interface. This is the initial state for all memberships on + // all network interfaces; it requires no storage in the host." + _ hostState = iota + + // delayingMember is the "'Delaying Member' state, when the host belongs to + // the group on the interface and has a report delay timer running for that + // membership." + delayingMember + + // idleMember is the "Idle Member" state, when the host belongs to the group + // on the interface and does not have a report delay timer running for that + // membership. + idleMember +) + +// init sets up an igmpState struct, and is required to be called before using +// a new igmpState. +func (igmp *igmpState) init(ep *endpoint) { + igmp.mu.Lock() + defer igmp.mu.Unlock() + igmp.ep = ep + igmp.mu.memberships = make(map[tcpip.Address]membershipInfo) + igmp.mu.igmpV1Present = igmpV1PresentDefault + igmp.mu.igmpV1Job = igmp.ep.protocol.stack.NewJob(&igmp.mu, func() { + igmp.mu.igmpV1Present = false + }) +} + +func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { + stats := igmp.ep.protocol.stack.Stats() + received := stats.IGMP.PacketsReceived + headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize) + if !ok { + received.Invalid.Increment() + return + } + h := header.IGMP(headerView) + + // Temporarily reset the checksum field to 0 in order to calculate the proper + // checksum. 
+ wantChecksum := h.Checksum() + h.SetChecksum(0) + gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) + h.SetChecksum(wantChecksum) + + if gotChecksum != wantChecksum { + received.ChecksumErrors.Increment() + return + } + + switch h.Type() { + case header.IGMPMembershipQuery: + received.MembershipQuery.Increment() + if len(headerView) < header.IGMPQueryMinimumSize { + received.Invalid.Increment() + return + } + igmp.handleMembershipQuery(h.GroupAddress(), h.MaxRespTime()) + case header.IGMPv1MembershipReport: + received.V1MembershipReport.Increment() + if len(headerView) < header.IGMPReportMinimumSize { + received.Invalid.Increment() + return + } + igmp.handleMembershipReport(h.GroupAddress()) + case header.IGMPv2MembershipReport: + received.V2MembershipReport.Increment() + if len(headerView) < header.IGMPReportMinimumSize { + received.Invalid.Increment() + return + } + igmp.handleMembershipReport(h.GroupAddress()) + case header.IGMPLeaveGroup: + received.LeaveGroup.Increment() + // As per RFC 2236 Section 6, Page 7: "IGMP messages other than Query or + // Report, are ignored in all states" + + default: + // As per RFC 2236 Section 2.1 Page 3: "Unrecognized message types should + // be silently ignored. New message types may be used by newer versions of + // IGMP, by multicast routing protocols, or other uses." + received.Unrecognized.Increment() + } +} + +func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime byte) { + igmp.mu.Lock() + defer igmp.mu.Unlock() + + // As per RFC 2236 Section 6, Page 10: If the maximum response time is zero + // then change the state to note that an IGMPv1 router is present and + // schedule the query received Job. + if maxRespTime == 0 { + igmp.mu.igmpV1Job.Cancel() + igmp.mu.igmpV1Job.Schedule(v1RouterPresentTimeout) + igmp.mu.igmpV1Present = true + maxRespTime = v1MaxRespTimeTenthSec + } + + // IPv4Any is the General Query Address. 
+ if groupAddress == header.IPv4Any { + for membershipAddress, info := range igmp.mu.memberships { + igmp.setDelayTimerForAddressRLocked(membershipAddress, &info, maxRespTime) + igmp.mu.memberships[membershipAddress] = info + } + } else if info, ok := igmp.mu.memberships[groupAddress]; ok { + igmp.setDelayTimerForAddressRLocked(groupAddress, &info, maxRespTime) + igmp.mu.memberships[groupAddress] = info + } +} + +// setDelayTimerForAddressRLocked modifies the passed info only and does not +// modify IGMP state directly. +// +// Precondition: igmp.mu MUST be read locked. +func (igmp *igmpState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *membershipInfo, maxRespTime byte) { + if info.state == delayingMember { + // As per RFC 2236 Section 3, page 3: "If a timer for the group is already + // running, it is reset to the random value only if the requested Max + // Response Time is less than the remaining value of the running timer. + // TODO: Reset the timer if time remaining is greater than maxRespTime. + return + } + info.state = delayingMember + info.delayedReportJob.Cancel() + info.delayedReportJob.Schedule(igmp.calculateDelayTimerDuration(maxRespTime)) +} + +func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) { + igmp.mu.Lock() + defer igmp.mu.Unlock() + + // As per RFC 2236 Section 3, pages 3-4: "If the host receives another host's + // Report (version 1 or 2) while it has a timer running, it stops its timer + // for the specified group and does not send a Report" + if info, ok := igmp.mu.memberships[groupAddress]; ok { + info.delayedReportJob.Cancel() + info.lastToSendReport = false + igmp.mu.memberships[groupAddress] = info + } +} + +// writePacket assembles and sends an IGMP packet with the provided fields, +// incrementing the provided stat counter on success. 
+func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) { + igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize)) + igmpData.SetType(igmpType) + igmpData.SetGroupAddress(groupAddress) + igmpData.SetChecksum(header.IGMPCalculateChecksum(igmpData)) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(igmp.ep.MaxHeaderLength()), + Data: buffer.View(igmpData).ToVectorisedView(), + }) + + // TODO(gvisor.dev/issue/4888): We should not use the unspecified address, + // rather we should select an appropriate local address. + r := stack.Route{ + LocalAddress: header.IPv4Any, + RemoteAddress: destAddress, + } + igmp.ep.addIPHeader(&r, pkt, stack.NetworkHeaderParams{ + Protocol: header.IGMPProtocolNumber, + TTL: header.IGMPTTL, + TOS: stack.DefaultTOS, + }) + + // TODO(b/162198658): set the ROUTER_ALERT option when sending Host + // Membership Reports. + sent := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent + if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { + sent.Dropped.Increment() + } else { + switch igmpType { + case header.IGMPv1MembershipReport: + sent.V1MembershipReport.Increment() + case header.IGMPv2MembershipReport: + sent.V2MembershipReport.Increment() + case header.IGMPLeaveGroup: + sent.LeaveGroup.Increment() + default: + panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType)) + } + } +} + +// sendReport sends a Host Membership Report in response to a query or after +// this host joins a new group on this interface. +// +// Precondition: igmp.mu MUST be locked. 
+func (igmp *igmpState) sendReportLocked(groupAddress tcpip.Address) { + igmpType := header.IGMPv2MembershipReport + if igmp.mu.igmpV1Present { + igmpType = header.IGMPv1MembershipReport + } + igmp.writePacket(groupAddress, groupAddress, igmpType) + + // Update the state of the membership for this group. If the group no longer + // exists, do nothing since this report must have been a race with a remove + // or is in the process of being added. + info, ok := igmp.mu.memberships[groupAddress] + if !ok { + return + } + info.state = idleMember + info.lastToSendReport = true + igmp.mu.memberships[groupAddress] = info +} + +// sendLeave sends a Leave Group report to the IPv4 All Routers Group. +// +// Precondition: igmp.mu MUST be read locked. +func (igmp *igmpState) sendLeaveRLocked(groupAddress tcpip.Address) { + // As per RFC 2236 Section 6, Page 8: "If the interface state says the + // Querier is running IGMPv1, this action SHOULD be skipped. If the flag + // saying we were the last host to report is cleared, this action MAY be + // skipped." + if igmp.mu.igmpV1Present || !igmp.mu.memberships[groupAddress].lastToSendReport { + return + } + + igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup) +} + +// joinGroup handles adding a new group to the membership map, setting up the +// IGMP state for the group, and sending and scheduling the required +// messages. +// +// If the group already exists in the membership map, returns +// tcpip.ErrDuplicateAddress. +func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) *tcpip.Error { + igmp.mu.Lock() + defer igmp.mu.Unlock() + if _, ok := igmp.mu.memberships[groupAddress]; ok { + // The group already exists in the membership map. + return tcpip.ErrDuplicateAddress + } + + info := membershipInfo{ + // There isn't a Job scheduled currently, so it's just idle. + state: idleMember, + // Joining a group immediately sends a report. 
+ lastToSendReport: true, + delayedReportJob: igmp.ep.protocol.stack.NewJob(&igmp.mu, func() { + igmp.sendReportLocked(groupAddress) + }), + } + + // As per RFC 2236 Section 3, Page 5: "When a host joins a multicast group, + // it should immediately transmit an unsolicited Version 2 Membership Report + // for that group" ... "it is recommended that it be repeated" + igmp.sendReportLocked(groupAddress) + igmp.setDelayTimerForAddressRLocked(groupAddress, &info, UnsolicitedReportIntervalMaxTenthSec) + igmp.mu.memberships[groupAddress] = info + + return nil +} + +// leaveGroup handles removing the group from the membership map, cancels any +// delay timers associated with that group, and sends the Leave Group message +// if required. +// +// If the group does not exist in the membership map, this function will +// silently return. +func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) { + igmp.mu.Lock() + defer igmp.mu.Unlock() + info, ok := igmp.mu.memberships[groupAddress] + if !ok { + return + } + + // Clean up the state of the group before sending the leave message and + // removing it from the map. + info.delayedReportJob.Cancel() + info.state = idleMember + igmp.mu.memberships[groupAddress] = info + + igmp.sendLeaveRLocked(groupAddress) + delete(igmp.mu.memberships, groupAddress) +} + +// RFC 2236 Section 3, Page 3: The response time is set to a "random value... +// selected from the range (0, Max Response Time]" where Max Resp Time is given +// in units of 1/10 of a second. +func (igmp *igmpState) calculateDelayTimerDuration(maxRespTime byte) time.Duration { + maxRespTimeDuration := DecisecondToSecond(maxRespTime) + return time.Duration(igmp.ep.protocol.stack.Rand().Int63n(int64(maxRespTimeDuration))) +} + +// DecisecondToSecond converts a byte representing deci-seconds to a Duration +// type. This helper function exists because the IGMP stack sends and receives +// Max Response Times in deci-seconds. 
+func DecisecondToSecond(ds byte) time.Duration { + return time.Duration(ds) * time.Second / 10 +} diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go new file mode 100644 index 000000000..a0f37885a --- /dev/null +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -0,0 +1,491 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + // endpointAddr = tcpip.Address("\x0a\x00\x00\x02") + multicastAddr = tcpip.Address("\xe0\x00\x00\x03") + nicID = 1 +) + +var ( + // unsolicitedReportIntervalMax is the maximum amount of time the NIC will + // wait before sending an unsolicited report after joining a multicast group. + unsolicitedReportIntervalMax = ipv4.DecisecondToSecond(ipv4.UnsolicitedReportIntervalMaxTenthSec) +) + +// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet +// sent to the provided address with the passed fields set. Raises a t.Error if +// any field does not match. 
+func validateIgmpPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) { + t.Helper() + + payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) + checker.IPv4(t, payload, + checker.DstAddr(remoteAddress), + checker.IGMP( + checker.IGMPType(igmpType), + checker.IGMPMaxRespTime(maxRespTime), + checker.IGMPGroupAddress(groupAddress), + ), + ) +} + +func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { + t.Helper() + + // Create an endpoint of queue size 1, since no more than 1 packets are ever + // queued in the tests in this file. + e := channel.New(1, 1280, linkAddr) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{ + IGMPEnabled: igmpEnabled, + })}, + Clock: clock, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + return e, s, clock +} + +func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) { + buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize) + + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(len(buf)), + TTL: 1, + Protocol: uint8(header.IGMPProtocolNumber), + SrcAddr: header.IPv4Any, + DstAddr: header.IPv4AllSystems, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + igmp := header.IGMP(buf[header.IPv4MinimumSize:]) + igmp.SetType(igmpType) + igmp.SetMaxRespTime(maxRespTime) + igmp.SetGroupAddress(groupAddress) + igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) + + e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) +} + +// TestIgmpDisabled tests that IGMP is not enabled with a default +// stack.Options. 
This also tests that this NIC does not send the IGMP Join +// Group for the All Hosts group it automatically joins when created. +func TestIgmpDisabled(t *testing.T) { + e, s, _ := createStack(t, false) + + // This NIC will join the All Hosts group when created. Verify that does not + // send a report. + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 0 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 0", got) + } + p, ok := e.Read() + if ok { + t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt) + } + + // Test joining a specific group explicitly and verify that no reports are + // sent. + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(ipv4.ProtocolNumber, %d, %s) = %s", nicID, multicastAddr, err) + } + + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 0 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 0", got) + } + p, ok = e.Read() + if ok { + t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt) + } + + // Inject a General Membership Query, which is an IGMP Membership Query with + // a zeroed Group Address (IPv4Any) to verify that it does not reach the + // handler. + createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, ipv4.UnsolicitedReportIntervalMaxTenthSec, header.IPv4Any) + + if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 0 { + t.Fatalf("got Membership Queries received = %d, want = 0", got) + } + p, ok = e.Read() + if ok { + t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt) + } +} + +// TestIgmpReceivesIGMPMessages tests that the IGMP stack increments packet +// counters when it receives properly formatted Membership Queries, Membership +// Reports, and LeaveGroup Messages sent to this address. 
Note: test includes +// IGMP header fields that are not explicitly tested in order to inject proper +// IGMP packets. +func TestIgmpReceivesIGMPMessages(t *testing.T) { + tests := []struct { + name string + headerType header.IGMPType + maxRespTime byte + groupAddress tcpip.Address + statCounter func(tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter + }{ + { + name: "General Membership Query", + headerType: header.IGMPMembershipQuery, + maxRespTime: ipv4.UnsolicitedReportIntervalMaxTenthSec, + groupAddress: header.IPv4Any, + statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter { + return stats.MembershipQuery + }, + }, + { + name: "IGMPv1 Membership Report", + headerType: header.IGMPv1MembershipReport, + maxRespTime: 0, + groupAddress: header.IPv4AllSystems, + statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter { + return stats.V1MembershipReport + }, + }, + { + name: "IGMPv2 Membership Report", + headerType: header.IGMPv2MembershipReport, + maxRespTime: 0, + groupAddress: header.IPv4AllSystems, + statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter { + return stats.V2MembershipReport + }, + }, + { + name: "Leave Group", + headerType: header.IGMPLeaveGroup, + maxRespTime: 0, + groupAddress: header.IPv4AllRoutersGroup, + statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter { + return stats.LeaveGroup + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, _ := createStack(t, true) + + createAndInjectIGMPPacket(e, test.headerType, test.maxRespTime, test.groupAddress) + + if got := test.statCounter(s.Stats().IGMP.PacketsReceived).Value(); got != 1 { + t.Fatalf("got %s received = %d, want = 1", test.name, got) + } + }) + } +} + +// TestIgmpJoinGroup tests that when explicitly joining a multicast group, the +// IGMP stack schedules and sends correct Membership Reports. 
+func TestIgmpJoinGroup(t *testing.T) { + e, s, clock := createStack(t, true) + + // Test joining a specific address explicitly and verify a Membership Report + // is sent immediately. + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } + + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + if t.Failed() { + t.FailNow() + } + + // Verify the second Membership Report is sent after a random interval up to + // the unsolicitedReportIntervalMax. + p, ok = e.Read() + if ok { + t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt) + } + clock.Advance(unsolicitedReportIntervalMax) + p, ok = e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) +} + +// TestIgmpLeaveGroup tests that when leaving a previously joined multicast +// group the IGMP enabled NIC sends the appropriate message. +func TestIgmpLeaveGroup(t *testing.T) { + e, s, clock := createStack(t, true) + + // Join a group so that it can be left, validate the immediate Membership + // Report is sent only to the multicast address joined. 
+ if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + if t.Failed() { + t.FailNow() + } + + // Verify the second Membership Report is sent after a random interval up to + // the unsolicitedReportIntervalMax, and is sent to the multicast address + // being joined. + p, ok = e.Read() + if ok { + t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt) + } + clock.Advance(unsolicitedReportIntervalMax) + p, ok = e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + if t.Failed() { + t.FailNow() + } + + // Now that there are no packets queued and none scheduled to be sent, leave + // the group. + if err := s.LeaveGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("LeaveGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + // Observe the Leave Group Message to verify that the Leave Group message is + // sent to the All Routers group but that the message itself has the + // multicast address being left. 
+ p, ok = e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected LeaveGroup") + } + if got := s.Stats().IGMP.PacketsSent.LeaveGroup.Value(); got != 1 { + t.Fatalf("got LeaveGroup messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, header.IPv4AllRoutersGroup, header.IGMPLeaveGroup, 0, multicastAddr) +} + +// TestIgmpJoinLeaveGroup tests that leaving a previously joined multicast +// group before the Unsolicited Report Interval elapses cancels the second +// membership report. +func TestIgmpJoinLeaveGroup(t *testing.T) { + _, s, clock := createStack(t, true) + + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + // Verify that this NIC sent a Membership Report for only the group just + // joined. + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } + + if err := s.LeaveGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("LeaveGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + // Wait for the standard IGMP Unsolicited Report Interval duration before + // verifying that the unsolicited Membership Report was not sent after + // leaving the group. + clock.Advance(unsolicitedReportIntervalMax) + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } +} + +// TestIgmpMembershipQueryReport tests the handling of both incoming IGMP +// Membership Queries and outgoing Membership Reports. 
+func TestIgmpMembershipQueryReport(t *testing.T) { + e, s, clock := createStack(t, true) + + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + if t.Failed() { + t.FailNow() + } + + p, ok = e.Read() + if ok { + t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt) + } + clock.Advance(unsolicitedReportIntervalMax) + p, ok = e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + + // Inject a General Membership Query, which is an IGMP Membership Query with + // a zeroed Group Address (IPv4Any) with the shortened Max Response Time. 
+ const maxRespTimeDS = 10 + createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, maxRespTimeDS, header.IPv4Any) + + p, ok = e.Read() + if ok { + t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt) + } + clock.Advance(ipv4.DecisecondToSecond(maxRespTimeDS)) + p, ok = e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 3 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 3", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) +} + +// TestIgmpMultipleHosts tests the handling of IGMP Leave when we are not the +// most recent IGMP host to join a multicast network. +func TestIgmpMultipleHosts(t *testing.T) { + e, s, clock := createStack(t, true) + + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + if t.Failed() { + t.FailNow() + } + + // Inject another Host's Join Group message so that this host is not the + // latest to send the report. Set Max Response Time to 0 for Membership + // Reports. + createAndInjectIGMPPacket(e, header.IGMPv2MembershipReport, 0, multicastAddr) + + if err := s.LeaveGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("LeaveGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + // Wait to be sure that no Leave Group messages were sent up to the max + // unsolicited report interval since it was not the last host to join this + // group. 
+ clock.Advance(unsolicitedReportIntervalMax) + if got := s.Stats().IGMP.PacketsSent.LeaveGroup.Value(); got != 0 { + t.Fatalf("got LeaveGroup messages sent = %d, want = 0", got) + } +} + +// TestIgmpV1Present tests the handling of the case where an IGMPv1 router is +// present on the network. The IGMP stack will then send IGMPv1 Membership +// reports for backwards compatibility. +func TestIgmpV1Present(t *testing.T) { + e, s, clock := createStack(t, true) + + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + // This NIC will send an IGMPv2 report immediately, before this test can get + // the IGMPv1 General Membership Query in. + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + if t.Failed() { + t.FailNow() + } + + // Inject an IGMPv1 General Membership Query which is identical to a standard + // membership query except the Max Response Time is set to 0, which will tell + // the stack that this is a router using IGMPv1. Send it to the all systems + // group which is the only group this host belongs to. + createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, 0, header.IPv4AllSystems) + if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 1 { + t.Fatalf("got Membership Queries received = %d, want = 1", got) + } + + // Before advancing the clock, verify that this host has not sent a + // V1MembershipReport yet. + if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 0 { + t.Fatalf("got V1MembershipReport messages sent = %d, want = 0", got) + } + + // Verify the solicited Membership Report is sent. 
Now that this NIC has seen + // an IGMPv1 query, it should send an IGMPv1 Membership Report. + p, ok = e.Read() + if ok { + t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt) + } + clock.Advance(unsolicitedReportIntervalMax) + p, ok = e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V1MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 { + t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr) +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index ea8505692..7c759be9a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -72,6 +72,7 @@ type endpoint struct { nic stack.NetworkInterface dispatcher stack.TransportDispatcher protocol *protocol + igmp igmpState // enabled is set to 1 when the enpoint is enabled and 0 when it is // disabled. @@ -94,6 +95,7 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa protocol: p, } e.mu.addressableEndpointState.Init(e) + e.igmp.init(e) return e } @@ -703,6 +705,13 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { e.handleICMP(pkt) return } + if p == header.IGMPProtocolNumber { + if e.protocol.options.IGMPEnabled { + e.igmp.handleIGMP(pkt) + } + // Nothing further to do with an IGMP packet, even if IGMP is not enabled. 
+ return + } if opts := h.Options(); len(opts) != 0 { // TODO(gvisor.dev/issue/4586): // When we add forwarding support we should use the verified options @@ -834,14 +843,26 @@ func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.JoinGroup(addr) + + joinedGroup, err := e.mu.addressableEndpointState.JoinGroup(addr) + if err == nil && joinedGroup && e.protocol.options.IGMPEnabled { + _ = e.igmp.joinGroup(addr) + } + + return joinedGroup, err } // LeaveGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.LeaveGroup(addr) + + leftGroup, err := e.mu.addressableEndpointState.LeaveGroup(addr) + if err == nil && leftGroup && e.protocol.options.IGMPEnabled { + e.igmp.leaveGroup(addr) + } + + return leftGroup, err } // IsInGroup implements stack.GroupAddressableEndpoint. @@ -874,6 +895,8 @@ type protocol struct { hashIV uint32 fragmentation *fragmentation.Fragmentation + + options Options } // Number returns the ipv4 protocol number. @@ -1007,8 +1030,15 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV ui return hash.Hash3Words(a, b, uint32(protocol), hashIV) } -// NewProtocol returns an IPv4 network protocol. -func NewProtocol(s *stack.Stack) stack.NetworkProtocol { +// Options holds options to configure a new protocol. +type Options struct { + // IGMPEnabled indicates whether incoming IGMP packets will be handled and if + // this endpoint will transmit IGMP packets on IGMP related events. + IGMPEnabled bool +} + +// NewProtocolWithOptions returns an IPv4 network protocol. +func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { ids := make([]uint32, buckets) // Randomly initialize hashIV and the ids. 
@@ -1018,14 +1048,22 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { } hashIV := r[buckets] - p := &protocol{ - stack: s, - ids: ids, - hashIV: hashIV, - defaultTTL: DefaultTTL, + return func(s *stack.Stack) stack.NetworkProtocol { + p := &protocol{ + stack: s, + ids: ids, + hashIV: hashIV, + defaultTTL: DefaultTTL, + options: opts, + } + p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) + return p } - p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) - return p +} + +// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options. +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { + return NewProtocolWithOptions(Options{})(s) } func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index c5d45ac6a..a2d234e7d 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1893,7 +1893,6 @@ func (s *Stack) RemoveTCPProbe() { // JoinGroup joins the given multicast group on the given NIC. func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { - // TODO: notify network of subscription via igmp protocol. s.mu.RLock() defer s.mu.RUnlock() diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index ad7a5b22f..3c7c5c0a8 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1434,6 +1434,60 @@ type ICMPStats struct { V6PacketsReceived ICMPv6ReceivedPacketStats } +// IGMPPacketStats enumerates counts for all IGMP packet types. +type IGMPPacketStats struct { + // MembershipQuery is the total number of Membership Query messages counted. 
+ MembershipQuery *StatCounter + + // V1MembershipReport is the total number of Version 1 Membership Report + // messages counted. + V1MembershipReport *StatCounter + + // V2MembershipReport is the total number of Version 2 Membership Report + // messages counted. + V2MembershipReport *StatCounter + + // LeaveGroup is the total number of Leave Group messages counted. + LeaveGroup *StatCounter +} + +// IGMPSentPacketStats collects outbound IGMP-specific stats. +type IGMPSentPacketStats struct { + IGMPPacketStats + + // Dropped is the total number of IGMP packets dropped due to link layer + // errors. + Dropped *StatCounter +} + +// IGMPReceivedPacketStats collects inbound IGMP-specific stats. +type IGMPReceivedPacketStats struct { + IGMPPacketStats + + // Invalid is the total number of IGMP packets received that IGMP could not + // parse. + Invalid *StatCounter + + // ChecksumErrors is the total number of IGMP packets dropped due to bad + // checksums. + ChecksumErrors *StatCounter + + // Unrecognized is the total number of unrecognized messages counted; these + // are silently ignored for forward-compatibility. + Unrecognized *StatCounter +} + +// IGMPStats collects IGMP-specific stats. +type IGMPStats struct { + // PacketsSent contains counts of sent packets by IGMP packet type and a + // single count of packets dropped due to link layer errors. + PacketsSent IGMPSentPacketStats + + // PacketsReceived contains counts of received packets by IGMP packet type, + // plus counts of invalid, bad-checksum, and unrecognized packets. + PacketsReceived IGMPReceivedPacketStats +} + // IPStats collects IP-specific stats (both v4 and v6). type IPStats struct { // PacketsReceived is the total number of IP packets received from the @@ -1639,6 +1693,9 @@ type Stats struct { // ICMP breaks out ICMP-specific stats (both v4 and v6). ICMP ICMPStats + // IGMP breaks out IGMP-specific stats. + IGMP IGMPStats + // IP breaks out IP-specific stats (both v4 and v6). 
IP IPStats diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 801cf8e0e..64563a8ba 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2797,7 +2797,7 @@ func (e *endpoint) getRemoteAddress() tcpip.FullAddress { func (*endpoint) HandlePacket(stack.TransportEndpointID, *stack.PacketBuffer) { // TCP HandlePacket is not required anymore as inbound packets first - // land at the Dispatcher which then can either delivery using the + // land at the Dispatcher which then can either deliver using the // worker go routine or directly do the invoke the tcp processing inline // based on the state of the endpoint. } |