diff options
Diffstat (limited to 'pkg/tcpip/network')
26 files changed, 3884 insertions, 722 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index b38aff0b8..9ebf31b78 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -7,12 +7,14 @@ go_test( size = "small", srcs = [ "ip_test.go", + "multicast_group_test.go", ], deps = [ "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/header/parse", "//pkg/tcpip/link/channel", diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index f462524c9..0fb373612 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -319,9 +319,9 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { copy(h.HardwareAddressSender(), test.senderLinkAddr) copy(h.ProtocolAddressSender(), test.senderAddr) copy(h.ProtocolAddressTarget(), test.targetAddr) - c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: v.ToVectorisedView(), - }) + })) if !test.isValid { // No packets should be sent after receiving an invalid ARP request. @@ -442,9 +442,9 @@ func (*testInterface) Promiscuous() bool { func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { r := stack.Route{ - NetProto: protocol, - RemoteLinkAddress: remoteLinkAddr, + NetProto: protocol, } + r.ResolveWith(remoteLinkAddr) return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) } @@ -557,8 +557,8 @@ func TestLinkAddressRequest(t *testing.T) { t.Fatal("expected to send a link address request") } - if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) + if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr) } rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index c75ca7d71..d31296a41 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -46,9 +46,13 @@ const ( ) var ( - // ErrInvalidArgs indicates to the caller that that an invalid argument was + // ErrInvalidArgs indicates to the caller that an invalid argument was // provided. ErrInvalidArgs = errors.New("invalid args") + + // ErrFragmentOverlap indicates that, during reassembly, a fragment overlaps + // with another one. + ErrFragmentOverlap = errors.New("overlapping fragments") ) // FragmentID is the identifier for a fragment. diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 19f4920b3..04072d966 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -26,9 +26,9 @@ import ( ) type hole struct { - first uint16 - last uint16 - deleted bool + first uint16 + last uint16 + filled bool } type reassembler struct { @@ -38,7 +38,7 @@ type reassembler struct { proto uint8 mu sync.Mutex holes []hole - deleted int + filled int heap fragHeap done bool creationTime int64 @@ -53,44 +53,86 @@ func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { creationTime: clock.NowMonotonic(), } r.holes = append(r.holes, hole{ - first: 0, - last: math.MaxUint16, - deleted: false}) + first: 0, + last: math.MaxUint16, + filled: false, + }) return r } -// updateHoles updates the list of holes for an incoming fragment and -// returns true iff the fragment filled at least part of an existing hole. -func (r *reassembler) updateHoles(first, last uint16, more bool) bool { - used := false +// updateHoles updates the list of holes for an incoming fragment. It returns +// true if the fragment fits, it is not a duplicate and it does not overlap with +// another fragment. +// +// For IPv6, overlaps with an existing fragment are explicitly forbidden by +// RFC 8200 section 4.5: +// If any of the fragments being reassembled overlap with any other fragments +// being reassembled for the same packet, reassembly of that packet must be +// abandoned and all the fragments that have been received for that packet +// must be discarded, and no ICMP error messages should be sent. +// +// It is not explicitly forbidden for IPv4, but to keep parity with Linux we +// disallow it as well: +// https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349 +func (r *reassembler) updateHoles(first, last uint16, more bool) (bool, error) { for i := range r.holes { - if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first { + currentHole := &r.holes[i] + + if currentHole.filled || last < currentHole.first || currentHole.last < first { continue } - used = true - r.deleted++ - r.holes[i].deleted = true - if first > r.holes[i].first { - r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false}) + + if first < currentHole.first || currentHole.last < last { + // Incoming fragment only partially fits in the free hole. + return false, ErrFragmentOverlap + } + + r.filled++ + if first > currentHole.first { + r.holes = append(r.holes, hole{ + first: currentHole.first, + last: first - 1, + filled: false, + }) + } + if last < currentHole.last && more { + r.holes = append(r.holes, hole{ + first: last + 1, + last: currentHole.last, + filled: false, + }) } - if last < r.holes[i].last && more { - r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false}) + // Update the current hole to precisely match the incoming fragment. + r.holes[i] = hole{ + first: first, + last: last, + filled: true, } + return true, nil } - return used + + // Incoming fragment is a duplicate/subset, or its offset comes after the end + // of the reassembled payload. + return false, nil } func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) { r.mu.Lock() defer r.mu.Unlock() - consumed := 0 if r.done { // A concurrent goroutine might have already reassembled // the packet and emptied the heap while this goroutine // was waiting on the mutex. We don't have to do anything in this case. - return buffer.VectorisedView{}, 0, false, consumed, nil + return buffer.VectorisedView{}, 0, false, 0, nil } - if r.updateHoles(first, last, more) { + + used, err := r.updateHoles(first, last, more) + if err != nil { + return buffer.VectorisedView{}, 0, false, 0, fmt.Errorf("fragment reassembly failed: %w", err) + } + + var consumed int + if used { // For IPv6, it is possible to have different Protocol values between // fragments of a packet (because, unlike IPv4, the Protocol is not used to // identify a fragment). In this case, only the Protocol of the first @@ -109,13 +151,14 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s consumed = vv.Size() r.size += consumed } - // Check if all the holes have been deleted and we are ready to reassamble. - if r.deleted < len(r.holes) { + + // Check if all the holes have been filled and we are ready to reassemble. + if r.filled < len(r.holes) { return buffer.VectorisedView{}, 0, false, consumed, nil } res, err := r.heap.reassemble() if err != nil { - return buffer.VectorisedView{}, 0, false, consumed, fmt.Errorf("fragment reassembly failed: %w", err) + return buffer.VectorisedView{}, 0, false, 0, fmt.Errorf("fragment reassembly failed: %w", err) } return res, r.proto, true, consumed, nil } diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index a0a04a027..cee3063b1 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -16,92 +16,124 @@ package fragmentation import ( "math" - "reflect" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip/faketime" ) -type updateHolesInput struct { - first uint16 - last uint16 - more bool +type updateHolesParams struct { + first uint16 + last uint16 + more bool + wantUsed bool + wantError error } -var holesTestCases = []struct { - comment string - in []updateHolesInput - want []hole -}{ - { - comment: "No fragments. Expected holes: {[0 -> inf]}.", - in: []updateHolesInput{}, - want: []hole{{first: 0, last: math.MaxUint16, deleted: false}}, - }, - { - comment: "One fragment at beginning. Expected holes: {[2, inf]}.", - in: []updateHolesInput{{first: 0, last: 1, more: true}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 2, last: math.MaxUint16, deleted: false}, +func TestUpdateHoles(t *testing.T) { + var tests = []struct { + name string + params []updateHolesParams + want []hole + }{ + { + name: "No fragments", + params: nil, + want: []hole{{first: 0, last: math.MaxUint16, filled: false}}, }, - }, - { - comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.", - in: []updateHolesInput{{first: 1, last: 2, more: true}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 0, last: 0, deleted: false}, - {first: 3, last: math.MaxUint16, deleted: false}, + { + name: "One fragment at beginning", + params: []updateHolesParams{{first: 0, last: 1, more: true, wantUsed: true, wantError: nil}}, + want: []hole{ + {first: 0, last: 1, filled: true}, + {first: 2, last: math.MaxUint16, filled: false}, + }, }, - }, - { - comment: "One fragment at the end. Expected holes: {[0, 0]}.", - in: []updateHolesInput{{first: 1, last: 2, more: false}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 0, last: 0, deleted: false}, + { + name: "One fragment in the middle", + params: []updateHolesParams{{first: 1, last: 2, more: true, wantUsed: true, wantError: nil}}, + want: []hole{ + {first: 1, last: 2, filled: true}, + {first: 0, last: 0, filled: false}, + {first: 3, last: math.MaxUint16, filled: false}, + }, }, - }, - { - comment: "One fragment completing a packet. Expected holes: {}.", - in: []updateHolesInput{{first: 0, last: 1, more: false}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, + { + name: "One fragment at the end", + params: []updateHolesParams{{first: 1, last: 2, more: false, wantUsed: true, wantError: nil}}, + want: []hole{ + {first: 1, last: 2, filled: true}, + {first: 0, last: 0, filled: false}, + }, }, - }, - { - comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.", - in: []updateHolesInput{ - {first: 0, last: 1, more: true}, - {first: 2, last: 3, more: false}, + { + name: "One fragment completing a packet", + params: []updateHolesParams{{first: 0, last: 1, more: false, wantUsed: true, wantError: nil}}, + want: []hole{ + {first: 0, last: 1, filled: true}, + }, }, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 2, last: math.MaxUint16, deleted: true}, + { + name: "Two fragments completing a packet", + params: []updateHolesParams{ + {first: 0, last: 1, more: true, wantUsed: true, wantError: nil}, + {first: 2, last: 3, more: false, wantUsed: true, wantError: nil}, + }, + want: []hole{ + {first: 0, last: 1, filled: true}, + {first: 2, last: 3, filled: true}, + }, }, - }, - { - comment: "Two overlapping fragments completing a packet. Expected holes: {}.", - in: []updateHolesInput{ - {first: 0, last: 2, more: true}, - {first: 2, last: 3, more: false}, + { + name: "Two fragments completing a packet with a duplicate", + params: []updateHolesParams{ + {first: 0, last: 1, more: true, wantUsed: true, wantError: nil}, + {first: 0, last: 1, more: true, wantUsed: false, wantError: nil}, + {first: 2, last: 3, more: false, wantUsed: true, wantError: nil}, + }, + want: []hole{ + {first: 0, last: 1, filled: true}, + {first: 2, last: 3, filled: true}, + }, }, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 3, last: math.MaxUint16, deleted: true}, + { + name: "Two overlapping fragments", + params: []updateHolesParams{ + {first: 0, last: 10, more: true, wantUsed: true, wantError: nil}, + {first: 5, last: 15, more: false, wantUsed: false, wantError: ErrFragmentOverlap}, + {first: 11, last: 15, more: false, wantUsed: true, wantError: nil}, + }, + want: []hole{ + {first: 0, last: 10, filled: true}, + {first: 11, last: 15, filled: true}, + }, }, - }, -} + { + name: "Out of bounds fragment", + params: []updateHolesParams{ + {first: 0, last: 10, more: true, wantUsed: true, wantError: nil}, + {first: 11, last: 15, more: false, wantUsed: true, wantError: nil}, + {first: 16, last: 20, more: false, wantUsed: false, wantError: nil}, + }, + want: []hole{ + {first: 0, last: 10, filled: true}, + {first: 11, last: 15, filled: true}, + }, + }, + } -func TestUpdateHoles(t *testing.T) { - for _, c := range holesTestCases { - r := newReassembler(FragmentID{}, &faketime.NullClock{}) - for _, i := range c.in { - r.updateHoles(i.first, i.last, i.more) - } - if !reflect.DeepEqual(r.holes, c.want) { - t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + r := newReassembler(FragmentID{}, &faketime.NullClock{}) + for _, param := range test.params { + used, err := r.updateHoles(param.first, param.last, param.more) + if used != param.wantUsed || err != param.wantError { + t.Errorf("got r.updateHoles(%d, %d, %t) = (%t, %v), want = (%t, %v)", param.first, param.last, param.more, used, err, param.wantUsed, param.wantError) + } + } + if diff := cmp.Diff(test.want, r.holes, cmp.AllowUnexported(hole{})); diff != "" { + t.Errorf("r.holes mismatch (-want +got):\n%s", diff) + } + }) } } diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD new file mode 100644 index 000000000..6ca200b48 --- /dev/null +++ b/pkg/tcpip/network/ip/BUILD @@ -0,0 +1,25 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "ip", + srcs = ["generic_multicast_protocol.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/sync", + "//pkg/tcpip", + ], +) + +go_test( + name = "ip_test", + size = "small", + srcs = ["generic_multicast_protocol_test.go"], + deps = [ + ":ip", + "//pkg/tcpip", + "//pkg/tcpip/faketime", + "@com_github_google_go_cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go new file mode 100644 index 000000000..f14e2a88a --- /dev/null +++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go @@ -0,0 +1,546 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ip holds IPv4/IPv6 common utilities. +package ip + +import ( + "fmt" + "math/rand" + "time" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +// hostState is the state a host may be in for a multicast group. +type hostState int + +// The states below are generic across IGMPv2 (RFC 2236 section 6) and MLDv1 +// (RFC 2710 section 5). Even though the states are generic across both IGMPv2 +// and MLDv1, IGMPv2 terminology will be used. +const ( + // nonMember is the "'Non-Member' state, when the host does not belong to the + // group on the interface. This is the initial state for all memberships on + // all network interfaces; it requires no storage in the host." + // + // 'Non-Listener' is the MLDv1 term used to describe this state. + // + // This state is used to keep track of groups that have been joined locally, + // but without advertising the membership to the network. + nonMember hostState = iota + + // delayingMember is the "'Delaying Member' state, when the host belongs to + // the group on the interface and has a report delay timer running for that + // membership." + // + // 'Delaying Listener' is the MLDv1 term used to describe this state. + delayingMember + + // idleMember is the "Idle Member" state, when the host belongs to the group + // on the interface and does not have a report delay timer running for that + // membership. + // + // 'Idle Listener' is the MLDv1 term used to describe this state. + idleMember +) + +// multicastGroupState holds the Generic Multicast Protocol state for a +// multicast group. +type multicastGroupState struct { + // joins is the number of times the group has been joined. + joins uint64 + + // state holds the host's state for the group. + state hostState + + // lastToSendReport is true if we sent the last report for the group. It is + // used to track whether there are other hosts on the subnet that are also + // members of the group. + // + // Defined in RFC 2236 section 6 page 9 for IGMPv2 and RFC 2710 section 5 page + // 8 for MLDv1. + lastToSendReport bool + + // delayedReportJob is used to delay sending responses to membership report + // messages in order to reduce duplicate reports from multiple hosts on the + // interface. + // + // Must not be nil. + delayedReportJob *tcpip.Job +} + +// GenericMulticastProtocolOptions holds options for the generic multicast +// protocol. +type GenericMulticastProtocolOptions struct { + // Enabled indicates whether the generic multicast protocol will be + // performed. + // + // When enabled, the protocol may transmit report and leave messages when + // joining and leaving multicast groups respectively, and handle incoming + // packets. + // + // When disabled, the protocol will still keep track of locally joined groups, + // it just won't transmit and handle packets, or update groups' state. + Enabled bool + + // Rand is the source of random numbers. + Rand *rand.Rand + + // Clock is the clock used to create timers. + Clock tcpip.Clock + + // Protocol is the implementation of the variant of multicast group protocol + // in use. + Protocol MulticastGroupProtocol + + // MaxUnsolicitedReportDelay is the maximum amount of time to wait between + // transmitting unsolicited reports. + // + // Unsolicited reports are transmitted when a group is newly joined. + MaxUnsolicitedReportDelay time.Duration + + // AllNodesAddress is a multicast address that all nodes on a network should + // be a member of. + // + // This address will not have the generic multicast protocol performed on it; + // it will be left in the non member/listener state, and packets will never + // be sent for it. + AllNodesAddress tcpip.Address +} + +// MulticastGroupProtocol is a multicast group protocol whose core state machine +// can be represented by GenericMulticastProtocolState. +type MulticastGroupProtocol interface { + // SendReport sends a multicast report for the specified group address. + SendReport(groupAddress tcpip.Address) *tcpip.Error + + // SendLeave sends a multicast leave for the specified group address. + SendLeave(groupAddress tcpip.Address) *tcpip.Error +} + +// GenericMulticastProtocolState is the per interface generic multicast protocol +// state. +// +// There is actually no protocol named "Generic Multicast Protocol". Instead, +// the term used to refer to a generic multicast protocol that applies to both +// IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state +// machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710. +// +// GenericMulticastProtocolState.Init MUST be called before calling any of +// the methods on GenericMulticastProtocolState. +type GenericMulticastProtocolState struct { + opts GenericMulticastProtocolOptions + + mu struct { + sync.RWMutex + + // memberships holds group addresses and their associated state. + memberships map[tcpip.Address]multicastGroupState + } +} + +// Init initializes the Generic Multicast Protocol state. +// +// maxUnsolicitedReportDelay is the maximum time between sending unsolicited +// reports after joining a group. +func (g *GenericMulticastProtocolState) Init(opts GenericMulticastProtocolOptions) { + g.mu.Lock() + defer g.mu.Unlock() + g.opts = opts + g.mu.memberships = make(map[tcpip.Address]multicastGroupState) +} + +// MakeAllNonMember transitions all groups to the non-member state. +// +// The groups will still be considered joined locally. +func (g *GenericMulticastProtocolState) MakeAllNonMember() { + if !g.opts.Enabled { + return + } + + g.mu.Lock() + defer g.mu.Unlock() + + for groupAddress, info := range g.mu.memberships { + g.transitionToNonMemberLocked(groupAddress, &info) + g.mu.memberships[groupAddress] = info + } +} + +// InitializeGroups initializes each group, as if they were newly joined but +// without affecting the groups' join count. +// +// Must only be called after calling MakeAllNonMember as a group should not be +// initialized while it is not in the non-member state. +func (g *GenericMulticastProtocolState) InitializeGroups() { + if !g.opts.Enabled { + return + } + + g.mu.Lock() + defer g.mu.Unlock() + + for groupAddress, info := range g.mu.memberships { + g.initializeNewMemberLocked(groupAddress, &info) + g.mu.memberships[groupAddress] = info + } +} + +// JoinGroup handles joining a new group. +// +// If dontInitialize is true, the group will be not be initialized and will be +// left in the non-member state - no packets will be sent for it until it is +// initialized via InitializeGroups. +func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, dontInitialize bool) { + g.mu.Lock() + defer g.mu.Unlock() + + if info, ok := g.mu.memberships[groupAddress]; ok { + // The group has already been joined. + info.joins++ + g.mu.memberships[groupAddress] = info + return + } + + info := multicastGroupState{ + // Since we just joined the group, its count is 1. + joins: 1, + // The state will be updated below, if required. + state: nonMember, + lastToSendReport: false, + delayedReportJob: tcpip.NewJob(g.opts.Clock, &g.mu, func() { + info, ok := g.mu.memberships[groupAddress] + if !ok { + panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress)) + } + + info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil + info.state = idleMember + g.mu.memberships[groupAddress] = info + }), + } + + if !dontInitialize && g.opts.Enabled { + g.initializeNewMemberLocked(groupAddress, &info) + } + + g.mu.memberships[groupAddress] = info +} + +// IsLocallyJoined returns true if the group is locally joined. +func (g *GenericMulticastProtocolState) IsLocallyJoined(groupAddress tcpip.Address) bool { + g.mu.RLock() + defer g.mu.RUnlock() + _, ok := g.mu.memberships[groupAddress] + return ok +} + +// LeaveGroup handles leaving the group. +// +// Returns false if the group is not currently joined. +func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) bool { + g.mu.Lock() + defer g.mu.Unlock() + + info, ok := g.mu.memberships[groupAddress] + if !ok { + return false + } + + if info.joins == 0 { + panic(fmt.Sprintf("tried to leave group %s with a join count of 0", groupAddress)) + } + info.joins-- + if info.joins != 0 { + // If we still have outstanding joins, then do nothing further. + g.mu.memberships[groupAddress] = info + return true + } + + g.transitionToNonMemberLocked(groupAddress, &info) + delete(g.mu.memberships, groupAddress) + return true +} + +// HandleQuery handles a query message with the specified maximum response time. +// +// If the group address is unspecified, then reports will be scheduled for all +// joined groups. +// +// Report(s) will be scheduled to be sent after a random duration between 0 and +// the maximum response time. +func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, maxResponseTime time.Duration) { + if !g.opts.Enabled { + return + } + + g.mu.Lock() + defer g.mu.Unlock() + + // As per RFC 2236 section 2.4 (for IGMPv2), + // + // In a Membership Query message, the group address field is set to zero + // when sending a General Query, and set to the group address being + // queried when sending a Group-Specific Query. + // + // As per RFC 2710 section 3.6 (for MLDv1), + // + // In a Query message, the Multicast Address field is set to zero when + // sending a General Query, and set to a specific IPv6 multicast address + // when sending a Multicast-Address-Specific Query. + if groupAddress.Unspecified() { + // This is a general query as the group address is unspecified. + for groupAddress, info := range g.mu.memberships { + g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) + g.mu.memberships[groupAddress] = info + } + } else if info, ok := g.mu.memberships[groupAddress]; ok { + g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) + g.mu.memberships[groupAddress] = info + } +} + +// HandleReport handles a report message. +// +// If the report is for a joined group, any active delayed report will be +// cancelled and the host state for the group transitions to idle. +func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) { + if !g.opts.Enabled { + return + } + + g.mu.Lock() + defer g.mu.Unlock() + + // As per RFC 2236 section 3 pages 3-4 (for IGMPv2), + // + // If the host receives another host's Report (version 1 or 2) while it has + // a timer running, it stops its timer for the specified group and does not + // send a Report + // + // As per RFC 2710 section 4 page 6 (for MLDv1), + // + // If a node receives another node's Report from an interface for a + // multicast address while it has a timer running for that same address + // on that interface, it stops its timer and does not send a Report for + // that address, thus suppressing duplicate reports on the link. + if info, ok := g.mu.memberships[groupAddress]; ok && info.state == delayingMember { + info.delayedReportJob.Cancel() + info.lastToSendReport = false + info.state = idleMember + g.mu.memberships[groupAddress] = info + } +} + +// initializeNewMemberLocked initializes a new group membership. +// +// Precondition: g.mu must be locked. +func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state != nonMember { + panic(fmt.Sprintf("state for group %s is not non-member; state = %d", groupAddress, info.state)) + } + + info.state = idleMember + + if groupAddress == g.opts.AllNodesAddress { + // As per RFC 2236 section 6 page 10 (for IGMPv2), + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + // + // As per RFC 2710 section 5 page 10 (for MLDv1), + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + return + } + + // As per RFC 2236 section 3 page 5 (for IGMPv2), + // + // When a host joins a multicast group, it should immediately transmit an + // unsolicited Version 2 Membership Report for that group" ... "it is + // recommended that it be repeated". + // + // As per RFC 2710 section 4 page 6 (for MLDv1), + // + // When a node starts listening to a multicast address on an interface, + // it should immediately transmit an unsolicited Report for that address + // on that interface, in case it is the first listener on the link. To + // cover the possibility of the initial Report being lost or damaged, it + // is recommended that it be repeated once or twice after short delays + // [Unsolicited Report Interval]. + // + // TODO(gvisor.dev/issue/4901): Support a configurable number of initial + // unsolicited reports. + info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil + g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay) +} + +// maybeSendLeave attempts to send a leave message. +func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) { + if !g.opts.Enabled || !lastToSendReport { + return + } + + if groupAddress == g.opts.AllNodesAddress { + // As per RFC 2236 section 6 page 10 (for IGMPv2), + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + // + // As per RFC 2710 section 5 page 10 (for MLDv1), + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + return + } + + // Okay to ignore the error here as if packet write failed, the multicast + // routers will eventually drop our membership anyways. If the interface is + // being disabled or removed, the generic multicast protocol's should be + // cleared eventually. + // + // As per RFC 2236 section 3 page 5 (for IGMPv2), + // + // When a router receives a Report, it adds the group being reported to + // the list of multicast group memberships on the network on which it + // received the Report and sets the timer for the membership to the + // [Group Membership Interval]. Repeated Reports refresh the timer. If + // no Reports are received for a particular group before this timer has + // expired, the router assumes that the group has no local members and + // that it need not forward remotely-originated multicasts for that + // group onto the attached network. + // + // As per RFC 2710 section 4 page 5 (for MLDv1), + // + // When a router receives a Report from a link, if the reported address + // is not already present in the router's list of multicast address + // having listeners on that link, the reported address is added to the + // list, its timer is set to [Multicast Listener Interval], and its + // appearance is made known to the router's multicast routing component. + // If a Report is received for a multicast address that is already + // present in the router's list, the timer for that address is reset to + // [Multicast Listener Interval]. If an address's timer expires, it is + // assumed that there are no longer any listeners for that address + // present on the link, so it is deleted from the list and its + // disappearance is made known to the multicast routing component. + // + // The requirement to send a leave message is also optional (it MAY be + // skipped): + // + // As per RFC 2236 section 6 page 8 (for IGMPv2), + // + // "send leave" for the group on the interface. If the interface + // state says the Querier is running IGMPv1, this action SHOULD be + // skipped. If the flag saying we were the last host to report is + // cleared, this action MAY be skipped. The Leave Message is sent to + // the ALL-ROUTERS group (224.0.0.2). + // + // As per RFC 2710 section 5 page 8 (for MLDv1), + // + // "send done" for the address on the interface. If the flag saying + // we were the last node to report is cleared, this action MAY be + // skipped. The Done message is sent to the link-scope all-routers + // address (FF02::2). + _ = g.opts.Protocol.SendLeave(groupAddress) +} + +// transitionToNonMemberLocked transitions the given multicast group the the +// non-member/listener state. +// +// Precondition: e.mu must be locked. +func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state == nonMember { + return + } + + info.delayedReportJob.Cancel() + g.maybeSendLeave(groupAddress, info.lastToSendReport) + info.lastToSendReport = false + info.state = nonMember +} + +// setDelayTimerForAddressRLocked sets timer to send a delay report. +// +// Precondition: g.mu MUST be read locked. +func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) { + if info.state == nonMember { + return + } + + if groupAddress == g.opts.AllNodesAddress { + // As per RFC 2236 section 6 page 10 (for IGMPv2), + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + // + // As per RFC 2710 section 5 page 10 (for MLDv1), + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + return + } + + // As per RFC 2236 section 3 page 3 (for IGMPv2), + // + // If a timer for the group is already unning, it is reset to the random + // value only if the requested Max Response Time is less than the remaining + // value of the running timer. + // + // As per RFC 2710 section 4 page 5 (for MLDv1), + // + // If a timer for any address is already running, it is reset to the new + // random value only if the requested Maximum Response Delay is less than + // the remaining value of the running timer. + if info.state == delayingMember { + // TODO: Reset the timer if time remaining is greater than maxResponseTime. + return + } + info.state = delayingMember + info.delayedReportJob.Cancel() + info.delayedReportJob.Schedule(g.calculateDelayTimerDuration(maxResponseTime)) +} + +// calculateDelayTimerDuration returns a random time between (0, maxRespTime]. +func (g *GenericMulticastProtocolState) calculateDelayTimerDuration(maxRespTime time.Duration) time.Duration { + // As per RFC 2236 section 3 page 3 (for IGMPv2), + // + // When a host receives a Group-Specific Query, it sets a delay timer to a + // random value selected from the range (0, Max Response Time]... + // + // As per RFC 2710 section 4 page 6 (for MLDv1), + // + // When a node receives a Multicast-Address-Specific Query, if it is + // listening to the queried Multicast Address on the interface from + // which the Query was received, it sets a delay timer for that address + // to a random value selected from the range [0, Maximum Response Delay], + // as above. + if maxRespTime == 0 { + return 0 + } + return time.Duration(g.opts.Rand.Int63n(int64(maxRespTime))) +} diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go new file mode 100644 index 000000000..670be30d4 --- /dev/null +++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go @@ -0,0 +1,576 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip_test + +import ( + "math/rand" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" +) + +const ( + addr1 = tcpip.Address("\x01") + addr2 = tcpip.Address("\x02") + addr3 = tcpip.Address("\x03") + addr4 = tcpip.Address("\x04") + + maxUnsolicitedReportDelay = time.Second +) + +var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil) + +type mockMulticastGroupProtocol struct { + sendReportGroupAddrCount map[tcpip.Address]int + sendLeaveGroupAddrCount map[tcpip.Address]int +} + +func (m *mockMulticastGroupProtocol) init() { + m.sendReportGroupAddrCount = make(map[tcpip.Address]int) + m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) +} + +func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcpip.Error { + m.sendReportGroupAddrCount[groupAddress]++ + return nil +} + +func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error { + m.sendLeaveGroupAddrCount[groupAddress]++ + return nil +} + +func checkProtocol(mgp *mockMulticastGroupProtocol, sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { + sendReportGroupAddressesMap := make(map[tcpip.Address]int) + for _, a := range sendReportGroupAddresses { + sendReportGroupAddressesMap[a] = 1 + } + + sendLeaveGroupAddressesMap := make(map[tcpip.Address]int) + for _, a := range sendLeaveGroupAddresses { + sendLeaveGroupAddressesMap[a] = 1 + } + + diff := cmp.Diff(mockMulticastGroupProtocol{ + sendReportGroupAddrCount: sendReportGroupAddressesMap, + sendLeaveGroupAddrCount: sendLeaveGroupAddressesMap, + }, *mgp, cmp.AllowUnexported(mockMulticastGroupProtocol{})) + mgp.init() + return diff +} + +func TestJoinGroup(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + shouldSendReports bool + }{ + { + name: "Normal group", + addr: addr1, + shouldSendReports: true, + }, + { + name: "All-nodes group", + addr: addr2, + shouldSendReports: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var g ip.GenericMulticastProtocolState + var mgp mockMulticastGroupProtocol + mgp.init() + clock := faketime.NewManualClock() + g.Init(ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(0)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr2, + }) + + // Joining a group should send a report immediately and another after + // a random interval between 0 and the maximum unsolicited report delay. + g.JoinGroup(test.addr, false /* dontInitialize */) + if test.shouldSendReports { + if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + clock.Advance(maxUnsolicitedReportDelay) + if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestLeaveGroup(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + shouldSendMessages bool + }{ + { + name: "Normal group", + addr: addr1, + shouldSendMessages: true, + }, + { + name: "All-nodes group", + addr: addr2, + shouldSendMessages: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var g ip.GenericMulticastProtocolState + var mgp mockMulticastGroupProtocol + mgp.init() + clock := faketime.NewManualClock() + g.Init(ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(1)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr2, + }) + + g.JoinGroup(test.addr, false /* dontInitialize */) + if test.shouldSendMessages { + if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Leaving a group should send a leave report immediately and cancel any + // delayed reports. + if !g.LeaveGroup(test.addr) { + t.Fatalf("got g.LeaveGroup(%s) = false, want = true", test.addr) + } + if test.shouldSendMessages { + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestHandleReport(t *testing.T) { + tests := []struct { + name string + reportAddr tcpip.Address + expectReportsFor []tcpip.Address + }{ + { + name: "Unpecified empty", + reportAddr: "", + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Unpecified any", + reportAddr: "\x00", + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Specified", + reportAddr: addr1, + expectReportsFor: []tcpip.Address{addr2}, + }, + { + name: "Specified all-nodes", + reportAddr: addr3, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Specified other", + reportAddr: addr4, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var g ip.GenericMulticastProtocolState + var mgp mockMulticastGroupProtocol + mgp.init() + clock := faketime.NewManualClock() + g.Init(ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(2)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr3, + }) + + g.JoinGroup(addr1, false /* dontInitialize */) + if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + g.JoinGroup(addr2, false /* dontInitialize */) + if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + g.JoinGroup(addr3, false /* dontInitialize */) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a report for a group we have a timer scheduled for should + // cancel our delayed report timer for the group. + g.HandleReport(test.reportAddr) + if len(test.expectReportsFor) != 0 { + clock.Advance(maxUnsolicitedReportDelay) + if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestHandleQuery(t *testing.T) { + tests := []struct { + name string + queryAddr tcpip.Address + maxDelay time.Duration + expectReportsFor []tcpip.Address + }{ + { + name: "Unpecified empty", + queryAddr: "", + maxDelay: 0, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Unpecified any", + queryAddr: "\x00", + maxDelay: 1, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Specified", + queryAddr: addr1, + maxDelay: 2, + expectReportsFor: []tcpip.Address{addr1}, + }, + { + name: "Specified all-nodes", + queryAddr: addr3, + maxDelay: 3, + expectReportsFor: nil, + }, + { + name: "Specified other", + queryAddr: addr4, + maxDelay: 4, + expectReportsFor: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var g ip.GenericMulticastProtocolState + var mgp mockMulticastGroupProtocol + mgp.init() + clock := faketime.NewManualClock() + g.Init(ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr3, + }) + + g.JoinGroup(addr1, false /* dontInitialize */) + if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + g.JoinGroup(addr2, false /* dontInitialize */) + if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + g.JoinGroup(addr3, false /* dontInitialize */) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + clock.Advance(maxUnsolicitedReportDelay) + if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a query should make us schedule a new delayed report if it + // is a query directed at us or a general query. + g.HandleQuery(test.queryAddr, test.maxDelay) + if len(test.expectReportsFor) != 0 { + clock.Advance(test.maxDelay) + if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestJoinCount(t *testing.T) { + var g ip.GenericMulticastProtocolState + var mgp mockMulticastGroupProtocol + mgp.init() + clock := faketime.NewManualClock() + g.Init(ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(4)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: time.Second, + }) + + // Set the join count to 2 for a group. + g.JoinGroup(addr1, false /* dontInitialize */) + if !g.IsLocallyJoined(addr1) { + t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + } + // Only the first join should trigger a report to be sent. + if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + g.JoinGroup(addr1, false /* dontInitialize */) + if !g.IsLocallyJoined(addr1) { + t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + } + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Group should still be considered joined after leaving once. + if !g.LeaveGroup(addr1) { + t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1) + } + if !g.IsLocallyJoined(addr1) { + t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + } + // A leave report should only be sent once the join count reaches 0. + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Leaving once more should actually remove us from the group. + if !g.LeaveGroup(addr1) { + t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1) + } + if g.IsLocallyJoined(addr1) { + t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1) + } + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Group should no longer be joined so we should not have anything to + // leave. + if g.LeaveGroup(addr1) { + t.Fatalf("got g.LeaveGroup(%s) = true, want = false", addr1) + } + if g.IsLocallyJoined(addr1) { + t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1) + } + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} + +func TestMakeAllNonMemberAndInitialize(t *testing.T) { + var g ip.GenericMulticastProtocolState + var mgp mockMulticastGroupProtocol + mgp.init() + clock := faketime.NewManualClock() + g.Init(ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr3, + }) + + g.JoinGroup(addr1, false /* dontInitialize */) + if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + g.JoinGroup(addr2, false /* dontInitialize */) + if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + g.JoinGroup(addr3, false /* dontInitialize */) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should send the leave reports for each but still consider them locally + // joined. + g.MakeAllNonMember() + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + clock.Advance(time.Hour) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + for _, group := range []tcpip.Address{addr1, addr2, addr3} { + if !g.IsLocallyJoined(group) { + t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", group) + } + } + + // Should send the initial set of unsolcited reports. + g.InitializeGroups() + if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + clock.Advance(maxUnsolicitedReportDelay) + if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} + +// TestGroupStateNonMember tests that groups do not send packets when in the +// non-member state, but are still considered locally joined. +func TestGroupStateNonMember(t *testing.T) { + tests := []struct { + name string + enabled bool + dontInitialize bool + }{ + { + name: "Disabled", + enabled: false, + dontInitialize: false, + }, + { + name: "Keep non-member", + enabled: true, + dontInitialize: true, + }, + { + name: "disabled and Keep non-member", + enabled: false, + dontInitialize: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var g ip.GenericMulticastProtocolState + var mgp mockMulticastGroupProtocol + mgp.init() + clock := faketime.NewManualClock() + g.Init(ip.GenericMulticastProtocolOptions{ + Enabled: test.enabled, + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + }) + + g.JoinGroup(addr1, test.dontInitialize) + if !g.IsLocallyJoined(addr1) { + t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + } + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + g.JoinGroup(addr2, test.dontInitialize) + if !g.IsLocallyJoined(addr2) { + t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr2) + } + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + g.HandleQuery(addr1, time.Nanosecond) + clock.Advance(time.Nanosecond) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + if !g.LeaveGroup(addr2) { + t.Errorf("got g.LeaveGroup(%s) = false, want = true", addr2) + } + if !g.IsLocallyJoined(addr1) { + t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + } + if g.IsLocallyJoined(addr2) { + t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr2) + } + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + clock.Advance(time.Hour) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index d49c44846..a314dd386 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -193,10 +193,6 @@ func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBu panic("not implemented") } -func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { - return tcpip.ErrNotSupported -} - // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (*testObject) ARPHardwareType() header.ARPHardwareType { panic("not implemented") @@ -207,7 +203,7 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net panic("not implemented") } -func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { +func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, @@ -223,7 +219,7 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) } -func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { +func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, @@ -554,7 +550,7 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{ Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, @@ -937,7 +933,7 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{ Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, @@ -1093,7 +1089,19 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { dataBuf := [dataLen]byte{1, 2, 3, 4} data := dataBuf[:] - ipv4Options := header.IPv4Options{0, 1, 0, 1} + ipv4Options := header.IPv4OptionsSerializer{ + &header.IPv4SerializableListEndOption{}, + &header.IPv4SerializableNOPOption{}, + &header.IPv4SerializableListEndOption{}, + &header.IPv4SerializableNOPOption{}, + } + + expectOptions := header.IPv4Options{ + byte(header.IPv4OptionListEndType), + byte(header.IPv4OptionNOPType), + byte(header.IPv4OptionListEndType), + byte(header.IPv4OptionNOPType), + } ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4} ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:] @@ -1243,7 +1251,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ipHdrLen := header.IPv4MinimumSize + ipv4Options.SizeWithPadding() + ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) totalLen := ipHdrLen + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1266,7 +1274,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { netHdr := pkt.NetworkHeader() - hdrLen := header.IPv4MinimumSize + len(ipv4Options) + hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) if len(netHdr.View()) != hdrLen { t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) } @@ -1276,7 +1284,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { checker.DstAddr(remoteIPv4Addr), checker.IPv4HeaderLength(hdrLen), checker.IPFullLength(uint16(hdrLen+len(data))), - checker.IPv4Options(ipv4Options), + checker.IPv4Options(expectOptions), checker.IPPayload(data), ) }, @@ -1288,7 +1296,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.SizeWithPadding())) + ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length())) ip.Encode(&header.IPv4Fields{ Protocol: transportProto, TTL: ipv4.DefaultTTL, @@ -1307,7 +1315,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { netHdr := pkt.NetworkHeader() - hdrLen := header.IPv4MinimumSize + len(ipv4Options) + hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) if len(netHdr.View()) != hdrLen { t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) } @@ -1317,7 +1325,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { checker.DstAddr(remoteIPv4Addr), checker.IPv4HeaderLength(hdrLen), checker.IPFullLength(uint16(hdrLen+len(data))), - checker.IPv4Options(ipv4Options), + checker.IPv4Options(expectOptions), checker.IPPayload(data), ) }, diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 6252614ec..32f53f217 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -6,6 +6,7 @@ go_library( name = "ipv4", srcs = [ "icmp.go", + "igmp.go", "ipv4.go", ], visibility = ["//visibility:public"], @@ -17,6 +18,7 @@ go_library( "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", + "//pkg/tcpip/network/ip", "//pkg/tcpip/stack", ], ) @@ -24,7 +26,10 @@ go_library( go_test( name = "ipv4_test", size = "small", - srcs = ["ipv4_test.go"], + srcs = [ + "igmp_test.go", + "ipv4_test.go", + ], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 488945226..8e392f86c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -63,7 +63,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { stats := e.protocol.stack.Stats() - received := stats.ICMP.V4PacketsReceived + received := stats.ICMP.V4.PacketsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a // full explanation. @@ -130,7 +130,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4Echo: received.Echo.Increment() - sent := stats.ICMP.V4PacketsSent + sent := stats.ICMP.V4.PacketsSent if !e.protocol.stack.AllowICMPMessage() { sent.RateLimited.Increment() return @@ -379,7 +379,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi } defer route.Release() - sent := p.stack.Stats().ICMP.V4PacketsSent + sent := p.stack.Stats().ICMP.V4.PacketsSent if !p.stack.AllowICMPMessage() { sent.RateLimited.Increment() return nil diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go new file mode 100644 index 000000000..0134fadc0 --- /dev/null +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -0,0 +1,323 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4 + +import ( + "fmt" + "sync" + "sync/atomic" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + // igmpV1PresentDefault is the initial state for igmpV1Present in the + // igmpState. As per RFC 2236 Page 9 says "No IGMPv1 Router Present ... is + // the initial state." + igmpV1PresentDefault = 0 + + // v1RouterPresentTimeout from RFC 2236 Section 8.11, Page 18 + // See note on igmpState.igmpV1Present for more detail. + v1RouterPresentTimeout = 400 * time.Second + + // v1MaxRespTime from RFC 2236 Section 4, Page 5. "The IGMPv1 router + // will send General Queries with the Max Response Time set to 0. This MUST + // be interpreted as a value of 100 (10 seconds)." + // + // Note that the Max Response Time field is a value in units of deciseconds. + v1MaxRespTime = 10 * time.Second + + // UnsolicitedReportIntervalMax is the maximum delay between sending + // unsolicited IGMP reports. + // + // Obtained from RFC 2236 Section 8.10, Page 19. + UnsolicitedReportIntervalMax = 10 * time.Second +) + +// IGMPOptions holds options for IGMP. +type IGMPOptions struct { + // Enabled indicates whether IGMP will be performed. + // + // When enabled, IGMP may transmit IGMP report and leave messages when + // joining and leaving multicast groups respectively, and handle incoming + // IGMP packets. + Enabled bool +} + +var _ ip.MulticastGroupProtocol = (*igmpState)(nil) + +// igmpState is the per-interface IGMP state. +// +// igmpState.init() MUST be called after creating an IGMP state. +type igmpState struct { + // The IPv4 endpoint this igmpState is for. + ep *endpoint + opts IGMPOptions + + // igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from + // RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1 + // Membership Reports in response to its Queries, and will not pay + // attention to Version 2 Membership Reports. Therefore, a state variable + // MUST be kept for each interface, describing whether the multicast + // Querier on that interface is running IGMPv1 or IGMPv2. This variable + // MUST be based upon whether or not an IGMPv1 query was heard in the last + // [Version 1 Router Present Timeout] seconds". + // + // Must be accessed with atomic operations. Holds a value of 1 when true, 0 + // when false. + igmpV1Present uint32 + + mu struct { + sync.RWMutex + + genericMulticastProtocol ip.GenericMulticastProtocolState + + // igmpV1Job is scheduled when this interface receives an IGMPv1 style + // message, upon expiration the igmpV1Present flag is cleared. + // igmpV1Job may not be nil once igmpState is initialized. + igmpV1Job *tcpip.Job + } +} + +// SendReport implements ip.MulticastGroupProtocol. +func (igmp *igmpState) SendReport(groupAddress tcpip.Address) *tcpip.Error { + igmpType := header.IGMPv2MembershipReport + if igmp.v1Present() { + igmpType = header.IGMPv1MembershipReport + } + return igmp.writePacket(groupAddress, groupAddress, igmpType) +} + +// SendLeave implements ip.MulticastGroupProtocol. +func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { + // As per RFC 2236 Section 6, Page 8: "If the interface state says the + // Querier is running IGMPv1, this action SHOULD be skipped. If the flag + // saying we were the last host to report is cleared, this action MAY be + // skipped." + if igmp.v1Present() { + return nil + } + return igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup) +} + +// init sets up an igmpState struct, and is required to be called before using +// a new igmpState. +func (igmp *igmpState) init(ep *endpoint, opts IGMPOptions) { + igmp.mu.Lock() + defer igmp.mu.Unlock() + igmp.ep = ep + igmp.opts = opts + igmp.mu.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{ + Enabled: opts.Enabled, + Rand: ep.protocol.stack.Rand(), + Clock: ep.protocol.stack.Clock(), + Protocol: igmp, + MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax, + AllNodesAddress: header.IPv4AllSystems, + }) + igmp.igmpV1Present = igmpV1PresentDefault + igmp.mu.igmpV1Job = igmp.ep.protocol.stack.NewJob(&igmp.mu, func() { + igmp.setV1Present(false) + }) +} + +func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { + stats := igmp.ep.protocol.stack.Stats() + received := stats.IGMP.PacketsReceived + headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize) + if !ok { + received.Invalid.Increment() + return + } + h := header.IGMP(headerView) + + // Temporarily reset the checksum field to 0 in order to calculate the proper + // checksum. + wantChecksum := h.Checksum() + h.SetChecksum(0) + gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) + h.SetChecksum(wantChecksum) + + if gotChecksum != wantChecksum { + received.ChecksumErrors.Increment() + return + } + + switch h.Type() { + case header.IGMPMembershipQuery: + received.MembershipQuery.Increment() + if len(headerView) < header.IGMPQueryMinimumSize { + received.Invalid.Increment() + return + } + igmp.handleMembershipQuery(h.GroupAddress(), h.MaxRespTime()) + case header.IGMPv1MembershipReport: + received.V1MembershipReport.Increment() + if len(headerView) < header.IGMPReportMinimumSize { + received.Invalid.Increment() + return + } + igmp.handleMembershipReport(h.GroupAddress()) + case header.IGMPv2MembershipReport: + received.V2MembershipReport.Increment() + if len(headerView) < header.IGMPReportMinimumSize { + received.Invalid.Increment() + return + } + igmp.handleMembershipReport(h.GroupAddress()) + case header.IGMPLeaveGroup: + received.LeaveGroup.Increment() + // As per RFC 2236 Section 6, Page 7: "IGMP messages other than Query or + // Report, are ignored in all states" + + default: + // As per RFC 2236 Section 2.1 Page 3: "Unrecognized message types should + // be silently ignored. New message types may be used by newer versions of + // IGMP, by multicast routing protocols, or other uses." + received.Unrecognized.Increment() + } +} + +func (igmp *igmpState) v1Present() bool { + return atomic.LoadUint32(&igmp.igmpV1Present) == 1 +} + +func (igmp *igmpState) setV1Present(v bool) { + if v { + atomic.StoreUint32(&igmp.igmpV1Present, 1) + } else { + atomic.StoreUint32(&igmp.igmpV1Present, 0) + } +} + +func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) { + igmp.mu.Lock() + defer igmp.mu.Unlock() + + // As per RFC 2236 Section 6, Page 10: If the maximum response time is zero + // then change the state to note that an IGMPv1 router is present and + // schedule the query received Job. + if maxRespTime == 0 && igmp.opts.Enabled { + igmp.mu.igmpV1Job.Cancel() + igmp.mu.igmpV1Job.Schedule(v1RouterPresentTimeout) + igmp.setV1Present(true) + maxRespTime = v1MaxRespTime + } + + igmp.mu.genericMulticastProtocol.HandleQuery(groupAddress, maxRespTime) +} + +func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) { + igmp.mu.Lock() + defer igmp.mu.Unlock() + igmp.mu.genericMulticastProtocol.HandleReport(groupAddress) +} + +// writePacket assembles and sends an IGMP packet with the provided fields, +// incrementing the provided stat counter on success. +func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) *tcpip.Error { + igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize)) + igmpData.SetType(igmpType) + igmpData.SetGroupAddress(groupAddress) + igmpData.SetChecksum(header.IGMPCalculateChecksum(igmpData)) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(igmp.ep.MaxHeaderLength()), + Data: buffer.View(igmpData).ToVectorisedView(), + }) + + // TODO(gvisor.dev/issue/4888): We should not use the unspecified address, + // rather we should select an appropriate local address. + localAddr := header.IPv4Any + igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{ + Protocol: header.IGMPProtocolNumber, + TTL: header.IGMPTTL, + TOS: stack.DefaultTOS, + }, header.IPv4OptionsSerializer{ + &header.IPv4SerializableRouterAlertOption{}, + }) + + sent := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent + if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + sent.Dropped.Increment() + return err + } + switch igmpType { + case header.IGMPv1MembershipReport: + sent.V1MembershipReport.Increment() + case header.IGMPv2MembershipReport: + sent.V2MembershipReport.Increment() + case header.IGMPLeaveGroup: + sent.LeaveGroup.Increment() + default: + panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType)) + } + return nil +} + +// joinGroup handles adding a new group to the membership map, setting up the +// IGMP state for the group, and sending and scheduling the required +// messages. +// +// If the group already exists in the membership map, returns +// tcpip.ErrDuplicateAddress. +func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) { + igmp.mu.Lock() + defer igmp.mu.Unlock() + igmp.mu.genericMulticastProtocol.JoinGroup(groupAddress, !igmp.ep.Enabled() /* dontInitialize */) +} + +// isInGroup returns true if the specified group has been joined locally. +func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool { + igmp.mu.Lock() + defer igmp.mu.Unlock() + return igmp.mu.genericMulticastProtocol.IsLocallyJoined(groupAddress) +} + +// leaveGroup handles removing the group from the membership map, cancels any +// delay timers associated with that group, and sends the Leave Group message +// if required. +func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { + igmp.mu.Lock() + defer igmp.mu.Unlock() + + // LeaveGroup returns false only if the group was not joined. + if igmp.mu.genericMulticastProtocol.LeaveGroup(groupAddress) { + return nil + } + + return tcpip.ErrBadLocalAddress +} + +// softLeaveAll leaves all groups from the perspective of IGMP, but remains +// joined locally. +func (igmp *igmpState) softLeaveAll() { + igmp.mu.Lock() + defer igmp.mu.Unlock() + igmp.mu.genericMulticastProtocol.MakeAllNonMember() +} + +// initializeAll attemps to initialize the IGMP state for each group that has +// been joined locally. +func (igmp *igmpState) initializeAll() { + igmp.mu.Lock() + defer igmp.mu.Unlock() + igmp.mu.genericMulticastProtocol.InitializeGroups() +} diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go new file mode 100644 index 000000000..5e139377b --- /dev/null +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -0,0 +1,156 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + multicastAddr = tcpip.Address("\xe0\x00\x00\x03") + nicID = 1 +) + +// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet +// sent to the provided address with the passed fields set. Raises a t.Error if +// any field does not match. +func validateIgmpPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) { + t.Helper() + + payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) + checker.IPv4(t, payload, + checker.DstAddr(remoteAddress), + // TTL for an IGMP message must be 1 as per RFC 2236 section 2. + checker.TTL(1), + checker.IPv4RouterAlert(), + checker.IGMP( + checker.IGMPType(igmpType), + checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)), + checker.IGMPGroupAddress(groupAddress), + ), + ) +} + +func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { + t.Helper() + + // Create an endpoint of queue size 1, since no more than 1 packets are ever + // queued in the tests in this file. + e := channel.New(1, 1280, linkAddr) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{ + IGMP: ipv4.IGMPOptions{ + Enabled: igmpEnabled, + }, + })}, + Clock: clock, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + return e, s, clock +} + +func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) { + buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize) + + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(len(buf)), + TTL: 1, + Protocol: uint8(header.IGMPProtocolNumber), + SrcAddr: header.IPv4Any, + DstAddr: header.IPv4AllSystems, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + igmp := header.IGMP(buf[header.IPv4MinimumSize:]) + igmp.SetType(igmpType) + igmp.SetMaxRespTime(maxRespTime) + igmp.SetGroupAddress(groupAddress) + igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) + + e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) +} + +// TestIgmpV1Present tests the handling of the case where an IGMPv1 router is +// present on the network. The IGMP stack will then send IGMPv1 Membership +// reports for backwards compatibility. +func TestIgmpV1Present(t *testing.T) { + e, s, clock := createStack(t, true) + + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + // This NIC will send an IGMPv2 report immediately, before this test can get + // the IGMPv1 General Membership Query in. + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + if t.Failed() { + t.FailNow() + } + + // Inject an IGMPv1 General Membership Query which is identical to a standard + // membership query except the Max Response Time is set to 0, which will tell + // the stack that this is a router using IGMPv1. Send it to the all systems + // group which is the only group this host belongs to. + createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, 0, header.IPv4AllSystems) + if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 1 { + t.Fatalf("got Membership Queries received = %d, want = 1", got) + } + + // Before advancing the clock, verify that this host has not sent a + // V1MembershipReport yet. + if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 0 { + t.Fatalf("got V1MembershipReport messages sent = %d, want = 0", got) + } + + // Verify the solicited Membership Report is sent. Now that this NIC has seen + // an IGMPv1 query, it should send an IGMPv1 Membership Report. + p, ok = e.Read() + if ok { + t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt) + } + clock.Advance(ipv4.UnsolicitedReportIntervalMax) + p, ok = e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V1MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 { + t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr) +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 1efe6297a..3076185cd 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -72,6 +72,7 @@ type endpoint struct { nic stack.NetworkInterface dispatcher stack.TransportDispatcher protocol *protocol + igmp igmpState // enabled is set to 1 when the enpoint is enabled and 0 when it is // disabled. @@ -94,6 +95,7 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa protocol: p, } e.mu.addressableEndpointState.Init(e) + e.igmp.init(e, p.options.IGMP) return e } @@ -121,11 +123,22 @@ func (e *endpoint) Enable() *tcpip.Error { // We have no need for the address endpoint. ep.DecRef() + // Groups may have been joined while the endpoint was disabled, or the + // endpoint may have left groups from the perspective of IGMP when the + // endpoint was disabled. Either way, we need to let routers know to + // send us multicast traffic. + e.igmp.initializeAll() + // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts // multicast group. Note, the IANA calls the all-hosts multicast group the // all-systems multicast group. - _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems) - return err + if err := e.joinGroupLocked(header.IPv4AllSystems); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv4 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllSystems, err)) + } + + return nil } // Enabled implements stack.NetworkEndpoint. @@ -162,10 +175,14 @@ func (e *endpoint) disableLocked() { } // The endpoint may have already left the multicast group. - if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { + if err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) } + // Leave groups from the perspective of IGMP so that routers know that + // we are no longer interested in the group. + e.igmp.softLeaveAll() + // The address may have already been removed. if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err)) @@ -198,37 +215,34 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) { hdrLen := header.IPv4MinimumSize - var opts header.IPv4Options - if params.Options != nil { - var ok bool - if opts, ok = params.Options.(header.IPv4Options); !ok { - panic(fmt.Sprintf("want IPv4Options, got %T", params.Options)) - } - hdrLen += opts.SizeWithPadding() - if hdrLen > header.IPv4MaximumHeaderSize { - // Since we have no way to report an error we must either panic or create - // a packet which is different to what was requested. Choose panic as this - // would be a programming error that should be caught in testing. - panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", params.Options.SizeWithPadding(), header.IPv4MaximumOptionsSize)) - } + var optLen int + if options != nil { + optLen = int(options.Length()) + } + hdrLen += optLen + if hdrLen > header.IPv4MaximumHeaderSize { + // Since we have no way to report an error we must either panic or create + // a packet which is different to what was requested. Choose panic as this + // would be a programming error that should be caught in testing. + panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", optLen, header.IPv4MaximumOptionsSize)) } ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen)) length := uint16(pkt.Size()) // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic // datagrams. Since the DF bit is never being set here, all datagrams // are non-atomic and need an ID. - id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) + id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1) ip.Encode(&header.IPv4Fields{ TotalLength: length, ID: uint16(id), TTL: params.TTL, TOS: params.TOS, Protocol: uint8(params.Protocol), - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, - Options: opts, + SrcAddr: srcAddr, + DstAddr: dstAddr, + Options: options, }) ip.SetChecksum(^ip.CalculateChecksum()) pkt.NetworkProtocolNumber = ProtocolNumber @@ -259,7 +273,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r, pkt, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */) // iptables filtering. All packets that reach here are locally // generated. @@ -347,7 +361,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.addIPHeader(r, pkt, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */) networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) @@ -461,7 +475,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // non-atomic datagrams, so assign an ID to all such datagrams // according to the definition given in RFC 6864 section 4. if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 { - ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) + ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress, r.RemoteAddress, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) } } @@ -566,21 +580,6 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { stats.IP.MalformedPacketsReceived.Increment() return } - srcAddr := h.SourceAddress() - dstAddr := h.DestinationAddress() - - addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) - if addressEndpoint == nil { - if !e.protocol.Forwarding() { - stats.IP.InvalidDestinationAddressesReceived.Increment() - return - } - - _ = e.forwardPacket(pkt) - return - } - subnet := addressEndpoint.AddressWithPrefix().Subnet() - addressEndpoint.DecRef() // There has been some confusion regarding verifying checksums. We need // just look for negative 0 (0xffff) as the checksum, as it's not possible to @@ -608,16 +607,42 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { return } + srcAddr := h.SourceAddress() + dstAddr := h.DestinationAddress() + // As per RFC 1122 section 3.2.1.3: // When a host sends any datagram, the IP source address MUST // be one of its own IP addresses (but not a broadcast or // multicast address). - if directedBroadcast := subnet.IsBroadcast(srcAddr); directedBroadcast || srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) { + if srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) { stats.IP.InvalidSourceAddressesReceived.Increment() return } + // Make sure the source address is not a subnet-local broadcast address. + if addressEndpoint := e.AcquireAssignedAddress(srcAddr, false /* createTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil { + subnet := addressEndpoint.Subnet() + addressEndpoint.DecRef() + if subnet.IsBroadcast(srcAddr) { + stats.IP.InvalidSourceAddressesReceived.Increment() + return + } + } - pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast + // The destination address should be an address we own or a group we joined + // for us to receive the packet. Otherwise, attempt to forward the packet. + if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil { + subnet := addressEndpoint.AddressWithPrefix().Subnet() + addressEndpoint.DecRef() + pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast + } else if !e.IsInGroup(dstAddr) { + if !e.protocol.Forwarding() { + stats.IP.InvalidDestinationAddressesReceived.Increment() + return + } + + _ = e.forwardPacket(pkt) + return + } // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. @@ -692,6 +717,10 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { e.handleICMP(pkt) return } + if p == header.IGMPProtocolNumber { + e.igmp.handleIGMP(pkt) + return + } if opts := h.Options(); len(opts) != 0 { // TODO(gvisor.dev/issue/4586): // When we add forwarding support we should use the verified options @@ -770,28 +799,12 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo defer e.mu.Unlock() loopback := e.nic.IsLoopback() - addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool { + return e.mu.addressableEndpointState.AcquireAssignedAddressOrMatching(localAddr, func(addressEndpoint stack.AddressEndpoint) bool { subnet := addressEndpoint.Subnet() // IPv4 has a notion of a subnet broadcast address and considers the // loopback interface bound to an address's whole subnet (on linux). return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr)) - }) - if addressEndpoint != nil { - return addressEndpoint - } - - if !allowTemp { - return nil - } - - addr := localAddr.WithPrefix() - addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(addr, tempPEB) - if err != nil { - // AddAddress only returns an error if the address is already assigned, - // but we just checked above if the address exists so we expect no error. - panic(fmt.Sprintf("e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(%s, %d): %s", addr, tempPEB, err)) - } - return addressEndpoint + }, allowTemp, tempPEB) } // AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint. @@ -816,28 +829,43 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { } // JoinGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + return e.joinGroupLocked(addr) +} + +// joinGroupLocked is like JoinGroup but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { if !header.IsV4MulticastAddress(addr) { - return false, tcpip.ErrBadAddress + return tcpip.ErrBadAddress } - e.mu.Lock() - defer e.mu.Unlock() - return e.mu.addressableEndpointState.JoinGroup(addr) + e.igmp.joinGroup(addr) + return nil } // LeaveGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.LeaveGroup(addr) + return e.leaveGroupLocked(addr) +} + +// leaveGroupLocked is like LeaveGroup but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { + return e.igmp.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.mu.addressableEndpointState.IsInGroup(addr) + return e.igmp.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) @@ -863,6 +891,8 @@ type protocol struct { hashIV uint32 fragmentation *fragmentation.Fragmentation + + options Options } // Number returns the ipv4 protocol number. @@ -987,17 +1017,23 @@ func addressToUint32(addr tcpip.Address) uint32 { return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24 } -// hashRoute calculates a hash value for the given route. It uses the source & -// destination address, the transport protocol number and a 32-bit number to -// generate the hash. -func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { - a := addressToUint32(r.LocalAddress) - b := addressToUint32(r.RemoteAddress) +// hashRoute calculates a hash value for the given source/destination pair using +// the addresses, transport protocol number and a 32-bit number to generate the +// hash. +func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { + a := addressToUint32(srcAddr) + b := addressToUint32(dstAddr) return hash.Hash3Words(a, b, uint32(protocol), hashIV) } -// NewProtocol returns an IPv4 network protocol. -func NewProtocol(s *stack.Stack) stack.NetworkProtocol { +// Options holds options to configure a new protocol. +type Options struct { + // IGMP holds options for IGMP. + IGMP IGMPOptions +} + +// NewProtocolWithOptions returns an IPv4 network protocol. +func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { ids := make([]uint32, buckets) // Randomly initialize hashIV and the ids. @@ -1007,14 +1043,22 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { } hashIV := r[buckets] - p := &protocol{ - stack: s, - ids: ids, - hashIV: hashIV, - defaultTTL: DefaultTTL, + return func(s *stack.Stack) stack.NetworkProtocol { + p := &protocol{ + stack: s, + ids: ids, + hashIV: hashIV, + defaultTTL: DefaultTTL, + options: opts, + } + p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) + return p } - p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) - return p +} + +// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options. +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { + return NewProtocolWithOptions(Options{})(s) } func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) { @@ -1129,6 +1173,12 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres } pointer := tsOpt.Pointer() + // RFC 791 page 22 states: "The smallest legal value is 5." + // Since the pointer is 1 based, and the header is 4 bytes long the + // pointer must point beyond the header therefore 4 or less is bad. + if pointer <= header.IPv4OptionTimestampHdrLength { + return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer + } // To simplify processing below, base further work on the array of timestamps // beyond the header, rather than on the whole option. Also to aid // calculations set 'nextSlot' to be 0 based as in the packet it is 1 based. @@ -1215,7 +1265,15 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength } - nextSlot := rrOpt.Pointer() - 1 // Pointer is 1 based. + pointer := rrOpt.Pointer() + // RFC 791 page 20 states: + // The pointer is relative to this option, and the + // smallest legal value for the pointer is 4. + // Since the pointer is 1 based, and the header is 3 bytes long the + // pointer must point beyond the header therefore 3 or less is bad. + if pointer <= header.IPv4OptionRecordRouteHdrLength { + return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer + } // RFC 791 page 21 says // If the route data area is already full (the pointer exceeds the @@ -1230,14 +1288,14 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // do this (as do most implementations). It is probable that the inclusion // of these words is a copy/paste error from the timestamp option where // there are two failure reasons given. - if nextSlot >= optlen { + if pointer > optlen { return 0, nil } // The data area isn't full but there isn't room for a new entry. // Either Length or Pointer could be bad. We must select Pointer for Linux - // compatibility, even if only the length is bad. - if nextSlot+header.IPv4AddressSize > optlen { + // compatibility, even if only the length is bad. NB. pointer is 1 based. + if pointer+header.IPv4AddressSize > optlen+1 { if false { // This is what we would do if we were not being Linux compatible. // Check for bad pointer or length value. Must be a multiple of 4 after diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 4e4e1f3b4..9e2d2cfd6 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -103,105 +103,6 @@ func TestExcludeBroadcast(t *testing.T) { }) } -// TestIPv4Encode checks that ipv4.Encode correctly fills out the requested -// fields when options are supplied. -func TestIPv4EncodeOptions(t *testing.T) { - tests := []struct { - name string - options header.IPv4Options - encodedOptions header.IPv4Options // reply should look like this - wantIHL int - }{ - { - name: "valid no options", - wantIHL: header.IPv4MinimumSize, - }, - { - name: "one byte options", - options: header.IPv4Options{1}, - encodedOptions: header.IPv4Options{1, 0, 0, 0}, - wantIHL: header.IPv4MinimumSize + 4, - }, - { - name: "two byte options", - options: header.IPv4Options{1, 1}, - encodedOptions: header.IPv4Options{1, 1, 0, 0}, - wantIHL: header.IPv4MinimumSize + 4, - }, - { - name: "three byte options", - options: header.IPv4Options{1, 1, 1}, - encodedOptions: header.IPv4Options{1, 1, 1, 0}, - wantIHL: header.IPv4MinimumSize + 4, - }, - { - name: "four byte options", - options: header.IPv4Options{1, 1, 1, 1}, - encodedOptions: header.IPv4Options{1, 1, 1, 1}, - wantIHL: header.IPv4MinimumSize + 4, - }, - { - name: "five byte options", - options: header.IPv4Options{1, 1, 1, 1, 1}, - encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0}, - wantIHL: header.IPv4MinimumSize + 8, - }, - { - name: "thirty nine byte options", - options: header.IPv4Options{ - 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16, - 17, 18, 19, 20, 21, 22, 23, 24, - 25, 26, 27, 28, 29, 30, 31, 32, - 33, 34, 35, 36, 37, 38, 39, - }, - encodedOptions: header.IPv4Options{ - 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16, - 17, 18, 19, 20, 21, 22, 23, 24, - 25, 26, 27, 28, 29, 30, 31, 32, - 33, 34, 35, 36, 37, 38, 39, 0, - }, - wantIHL: header.IPv4MinimumSize + 40, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - paddedOptionLength := test.options.SizeWithPadding() - ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength - if ipHeaderLength > header.IPv4MaximumHeaderSize { - t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) - } - totalLen := uint16(ipHeaderLength) - hdr := buffer.NewPrependable(int(totalLen)) - ip := header.IPv4(hdr.Prepend(ipHeaderLength)) - // To check the padding works, poison the last byte of the options space. - if paddedOptionLength != len(test.options) { - ip.SetHeaderLength(uint8(ipHeaderLength)) - ip.Options()[paddedOptionLength-1] = 0xff - ip.SetHeaderLength(0) - } - ip.Encode(&header.IPv4Fields{ - Options: test.options, - }) - options := ip.Options() - wantOptions := test.encodedOptions - if got, want := int(ip.HeaderLength()), test.wantIHL; got != want { - t.Errorf("got IHL of %d, want %d", got, want) - } - - // cmp.Diff does not consider nil slices equal to empty slices, but we do. - if len(wantOptions) == 0 && len(options) == 0 { - return - } - - if diff := cmp.Diff(wantOptions, options); diff != "" { - t.Errorf("options mismatch (-want +got):\n%s", diff) - } - }) - } -} - func TestForwarding(t *testing.T) { const ( nicID1 = 1 @@ -453,14 +354,6 @@ func TestIPv4Sanity(t *testing.T) { replyOptions: header.IPv4Options{1, 1, 0, 0}, }, { - name: "Check option padding", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{1, 1, 1}, - replyOptions: header.IPv4Options{1, 1, 1, 0}, - }, - { name: "bad header length", headerLength: header.IPv4MinimumSize - 1, maxTotalLength: ipv4.MaxTotalSize, @@ -583,7 +476,7 @@ func TestIPv4Sanity(t *testing.T) { 68, 7, 5, 0, // ^ ^ Linux points here which is wrong. // | Not a multiple of 4 - 1, 2, 3, + 1, 2, 3, 0, }, shouldFail: true, expectErrorICMP: true, @@ -662,6 +555,56 @@ func TestIPv4Sanity(t *testing.T) { }, }, { + // Timestamp pointer uses one based counting so 0 is invalid. + name: "timestamp pointer invalid", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 8, 0, 0x00, + // ^ 0 instead of 5 or more. + 0, 0, 0, 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + // Timestamp pointer cannot be less than 5. It must point past the header + // which is 4 bytes. (1 based counting) + name: "timestamp pointer too small by 1", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 8, header.IPv4OptionTimestampHdrLength, 0x00, + // ^ header is 4 bytes, so 4 should fail. + 0, 0, 0, 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + name: "valid timestamp pointer", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 8, header.IPv4OptionTimestampHdrLength + 1, 0x00, + // ^ header is 4 bytes, so 5 should succeed. + 0, 0, 0, 0, + }, + replyOptions: header.IPv4Options{ + 68, 8, 9, 0x00, + 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock + }, + }, + { // Needs 8 bytes for a type 1 timestamp but there are only 4 free. name: "bad timer element alignment", maxTotalLength: ipv4.MaxTotalSize, @@ -792,7 +735,61 @@ func TestIPv4Sanity(t *testing.T) { }, }, { - // Confirm linux bug for bug compatibility. + // Pointer uses one based counting so 0 is invalid. + name: "record route pointer zero", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 7, 8, 0, // 3 byte header + 0, 0, 0, 0, + 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + // Pointer must be 4 or more as it must point past the 3 byte header + // using 1 based counting. 3 should fail. + name: "record route pointer too small by 1", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 7, 8, header.IPv4OptionRecordRouteHdrLength, // 3 byte header + 0, 0, 0, 0, + 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + // Pointer must be 4 or more as it must point past the 3 byte header + // using 1 based counting. Check 4 passes. (Duplicates "single + // record route with room") + name: "valid record route pointer", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 7, 7, header.IPv4OptionRecordRouteHdrLength + 1, // 3 byte header + 0, 0, 0, 0, + 0, + }, + replyOptions: header.IPv4Options{ + 7, 7, 8, // 3 byte header + 192, 168, 1, 58, // New IP Address. + 0, // padding to multiple of 4 bytes. + }, + }, + { + // Confirm Linux bug for bug compatibility. // Linux returns slot 22 but the error is in slot 21. name: "multiple record route with not enough room", maxTotalLength: ipv4.MaxTotalSize, @@ -863,8 +860,10 @@ func TestIPv4Sanity(t *testing.T) { }, }) - paddedOptionLength := test.options.SizeWithPadding() - ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength + if len(test.options)%4 != 0 { + t.Fatalf("options must be aligned to 32 bits, invalid test options: %x (len=%d)", test.options, len(test.options)) + } + ipHeaderLength := header.IPv4MinimumSize + len(test.options) if ipHeaderLength > header.IPv4MaximumHeaderSize { t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) } @@ -883,11 +882,6 @@ func TestIPv4Sanity(t *testing.T) { if test.maxTotalLength < totalLen { totalLen = test.maxTotalLength } - // To check the padding works, poison the options space. - if paddedOptionLength != len(test.options) { - ip.SetHeaderLength(uint8(ipHeaderLength)) - ip.Options()[paddedOptionLength-1] = 0x01 - } ip.Encode(&header.IPv4Fields{ TotalLength: totalLen, @@ -895,10 +889,19 @@ func TestIPv4Sanity(t *testing.T) { TTL: test.TTL, SrcAddr: remoteIPv4Addr, DstAddr: ipv4Addr.Address, - Options: test.options, }) if test.headerLength != 0 { ip.SetHeaderLength(test.headerLength) + } else { + // Set the calculated header length, since we may manually add options. + ip.SetHeaderLength(uint8(ipHeaderLength)) + } + if len(test.options) != 0 { + // Copy options manually. We do not use Encode for options so we can + // verify malformed options with handcrafted payloads. + if want, got := copy(ip.Options(), test.options), len(test.options); want != got { + t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want) + } } ip.SetChecksum(0) ipHeaderChecksum := ip.CalculateChecksum() @@ -1003,7 +1006,7 @@ func TestIPv4Sanity(t *testing.T) { } // If the IP options change size then the packet will change size, so // some IP header fields will need to be adjusted for the checks. - sizeChange := len(test.replyOptions) - paddedOptionLength + sizeChange := len(test.replyOptions) - len(test.options) checker.IPv4(t, replyIPHeader, checker.IPv4HeaderLength(ipHeaderLength+sizeChange), @@ -2320,6 +2323,28 @@ func TestReceiveFragments(t *testing.T) { }, expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, }, + { + name: "Two fragments with MF flag reassembled into a maximum UDP packet", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload4Addr1ToAddr2[:65512], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 65512, + payload: ipv4Payload4Addr1ToAddr2[65512:], + }, + }, + expectedPayloads: nil, + }, } for _, test := range tests { @@ -2513,7 +2538,7 @@ func TestWriteStats(t *testing.T) { test.setup(t, rt.Stack()) - nWritten, _ := writer.writePackets(&rt, pkts) + nWritten, _ := writer.writePackets(rt, pkts) if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) @@ -2530,7 +2555,7 @@ func TestWriteStats(t *testing.T) { } } -func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { +func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, }) @@ -2644,8 +2669,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != header.IPv4ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) } checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), @@ -2687,8 +2712,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != header.IPv4ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) } checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), @@ -2736,8 +2761,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != arp.ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber) } - if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) + if got := p.Route.RemoteLinkAddress(); got != header.EthernetBroadcastAddress { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, header.EthernetBroadcastAddress) } rep := header.ARP(p.Pkt.NetworkHeader().View()) if got := rep.Op(); got != header.ARPRequest { diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index 0ac24a6fb..5e75c8740 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -8,6 +8,7 @@ go_library( "dhcpv6configurationfromndpra_string.go", "icmp.go", "ipv6.go", + "mld.go", "ndp.go", ], visibility = ["//visibility:public"], @@ -19,6 +20,7 @@ go_library( "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", + "//pkg/tcpip/network/ip", "//pkg/tcpip/stack", ], ) @@ -49,3 +51,16 @@ go_test( "@com_github_google_go_cmp//cmp:go_default_library", ], ) + +go_test( + name = "ipv6_x_test", + size = "small", + srcs = ["mld_test.go"], + deps = [ + ":ipv6", + "//pkg/tcpip/checker", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index beb8f562e..510276b8e 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -126,8 +126,8 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) { func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { stats := e.protocol.stack.Stats().ICMP - sent := stats.V6PacketsSent - received := stats.V6PacketsReceived + sent := stats.V6.PacketsSent + received := stats.V6.PacketsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a // full explanation. @@ -163,7 +163,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } // TODO(b/112892170): Meaningfully handle all ICMP types. - switch h.Type() { + switch icmpType := h.Type(); icmpType { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) @@ -358,7 +358,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize)) packet.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(packet.NDPPayload()) + na := header.NDPNeighborAdvert(packet.MessageBody()) // As per RFC 4861 section 7.2.4: // @@ -644,8 +644,31 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { return } + case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone: + var handler func(header.MLD) + switch icmpType { + case header.ICMPv6MulticastListenerQuery: + received.MulticastListenerQuery.Increment() + handler = e.mld.handleMulticastListenerQuery + case header.ICMPv6MulticastListenerReport: + received.MulticastListenerReport.Increment() + handler = e.mld.handleMulticastListenerReport + case header.ICMPv6MulticastListenerDone: + received.MulticastListenerDone.Increment() + default: + panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) + } + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize { + received.Invalid.Increment() + return + } + + if handler != nil { + handler(header.MLD(payload.ToView())) + } + default: - received.Invalid.Increment() + received.Unrecognized.Increment() } } @@ -681,12 +704,12 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) packet.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(packet.NDPPayload()) + ns := header.NDPNeighborSolicit(packet.MessageBody()) ns.SetTargetAddress(targetAddr) ns.Options().Serialize(optsSerializer) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - stat := p.stack.Stats().ICMP.V6PacketsSent + stat := p.stack.Stats().ICMP.V6.PacketsSent if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, @@ -796,7 +819,8 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi allowResponseToMulticast = reason.respondToMulticast } - if (!allowResponseToMulticast && header.IsV6MulticastAddress(origIPHdrDst)) || origIPHdrSrc == header.IPv6Any { + isOrigDstMulticast := header.IsV6MulticastAddress(origIPHdrDst) + if (!allowResponseToMulticast && isOrigDstMulticast) || origIPHdrSrc == header.IPv6Any { return nil } @@ -812,8 +836,13 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi // If we are operating as a router, do not use the packet's destination // address as the response's source address as we should not own the // destination address of a packet we are forwarding. + // + // If the packet was originally destined to a multicast address, then do not + // use the packet's destination address as the source for the response ICMP + // packet as "multicast addresses must not be used as source addresses in IPv6 + // packets", as per RFC 4291 section 2.7. localAddr := origIPHdrDst - if _, ok := reason.(*icmpReasonHopLimitExceeded); ok { + if _, ok := reason.(*icmpReasonHopLimitExceeded); ok || isOrigDstMulticast { localAddr = "" } // Even if we were able to receive a packet from some remote, we may not have @@ -827,7 +856,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi defer route.Release() stats := p.stack.Stats().ICMP - sent := stats.V6PacketsSent + sent := stats.V6.PacketsSent if !p.stack.AllowICMPMessage() { sent.RateLimited.Increment() return nil diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 9bc02d851..32adb5c83 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -150,9 +150,9 @@ func (*testInterface) Promiscuous() bool { func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { r := stack.Route{ - NetProto: protocol, - RemoteLinkAddress: remoteLinkAddr, + NetProto: protocol, } + r.ResolveWith(remoteLinkAddr) return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) } @@ -271,6 +271,22 @@ func TestICMPCounts(t *testing.T) { typ: header.ICMPv6RedirectMsg, size: header.ICMPv6MinimumSize, }, + { + typ: header.ICMPv6MulticastListenerQuery, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerReport, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerDone, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: 255, /* Unrecognized */ + size: 50, + }, } handleIPv6Payload := func(icmp header.ICMPv6) { @@ -301,7 +317,7 @@ func TestICMPCounts(t *testing.T) { // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) - icmpv6Stats := s.Stats().ICMP.V6PacketsReceived + icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { if got, want := s.Value(), uint64(1); got != want { t.Errorf("got %s = %d, want = %d", name, got, want) @@ -413,6 +429,22 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { typ: header.ICMPv6RedirectMsg, size: header.ICMPv6MinimumSize, }, + { + typ: header.ICMPv6MulticastListenerQuery, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerReport, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerDone, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: 255, /* Unrecognized */ + size: 50, + }, } handleIPv6Payload := func(icmp header.ICMPv6) { @@ -443,7 +475,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) - icmpv6Stats := s.Stats().ICMP.V6PacketsReceived + icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { if got, want := s.Value(), uint64(1); got != want { t.Errorf("got %s = %d, want = %d", name, got, want) @@ -568,8 +600,8 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. return } - if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress { - t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) + if got := pi.Route.RemoteLinkAddress(); len(args.remoteLinkAddr) != 0 && got != args.remoteLinkAddr { + t.Errorf("got remote link address = %s, want = %s", got, args.remoteLinkAddr) } // Pull the full payload since network header. Needed for header.IPv6 to @@ -833,7 +865,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { e.InjectInbound(ProtocolNumber, pkt) } - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid routerOnly := stats.RouterOnlyPacketsDroppedByHost typStat := typ.statCounter(stats) @@ -1028,7 +1060,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { e.InjectInbound(ProtocolNumber, pkt) } - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid typStat := typ.statCounter(stats) @@ -1207,7 +1239,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { e.InjectInbound(ProtocolNumber, pkt) } - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid typStat := typ.statCounter(stats) @@ -1349,8 +1381,8 @@ func TestLinkAddressRequest(t *testing.T) { if !ok { t.Fatal("expected to send a link address request") } - if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) + if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr) } if pkt.Route.RemoteAddress != test.expectedRemoteAddr { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr) @@ -1431,8 +1463,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1473,8 +1505,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1524,8 +1556,8 @@ func TestPacketQueing(t *testing.T) { t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber) } snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address) - if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1543,7 +1575,7 @@ func TestPacketQueing(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) pkt := header.ICMPv6(hdr.Prepend(naSize)) pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na := header.NDPNeighborAdvert(pkt.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address) @@ -1592,7 +1624,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) return icmp }, @@ -1612,7 +1644,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) ns.Options().Serialize(header.NDPOptionsSerializer{ header.NDPSourceLinkLayerAddressOption(linkAddr1), @@ -1629,7 +1661,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) return icmp }, @@ -1645,7 +1677,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) ns.Options().Serialize(header.NDPOptionsSerializer{ header.NDPSourceLinkLayerAddressOption(linkAddr1), @@ -1662,7 +1694,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) @@ -1683,7 +1715,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) @@ -1702,7 +1734,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(false) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) @@ -1722,7 +1754,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(false) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 7a00f6314..8bf84601f 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -34,7 +34,9 @@ import ( ) const ( + // ReassembleTimeout controls how long a fragment will be held. // As per RFC 8200 section 4.5: + // // If insufficient fragments are received to complete reassembly of a packet // within 60 seconds of the reception of the first-arriving fragment of that // packet, reassembly of that packet must be abandoned. @@ -84,6 +86,8 @@ type endpoint struct { addressableEndpointState stack.AddressableEndpointState ndp ndpState } + + mld mldState } // NICNameFromID is a function that returns a stable name for the specified NIC, @@ -224,6 +228,12 @@ func (e *endpoint) Enable() *tcpip.Error { return nil } + // Groups may have been joined when the endpoint was disabled, or the + // endpoint may have left groups from the perspective of MLD when the + // endpoint was disabled. Either way, we need to let routers know to + // send us multicast traffic. + e.mld.initializeAll() + // Join the IPv6 All-Nodes Multicast group if the stack is configured to // use IPv6. This is required to ensure that this node properly receives // and responds to the various NDP messages that are destined to the @@ -241,8 +251,10 @@ func (e *endpoint) Enable() *tcpip.Error { // (NDP NS) messages may be sent to the All-Nodes multicast group if the // source address of the NDP NS is the unspecified address, as per RFC 4861 // section 7.2.4. - if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil { - return err + if err := e.joinGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv6AllNodesMulticastAddress, err)) } // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent @@ -251,7 +263,7 @@ func (e *endpoint) Enable() *tcpip.Error { // Addresses may have aleady completed DAD but in the time since the endpoint // was last enabled, other devices may have acquired the same addresses. var err *tcpip.Error - e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool { + e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { addr := addressEndpoint.AddressWithPrefix().Address if !header.IsV6UnicastAddress(addr) { return true @@ -273,7 +285,7 @@ func (e *endpoint) Enable() *tcpip.Error { } // Do not auto-generate an IPv6 link-local address for loopback devices. - if e.protocol.autoGenIPv6LinkLocal && !e.nic.IsLoopback() { + if e.protocol.options.AutoGenLinkLocal && !e.nic.IsLoopback() { // The valid and preferred lifetime is infinite for the auto-generated // link-local address. e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) @@ -331,9 +343,13 @@ func (e *endpoint) disableLocked() { e.stopDADForPermanentAddressesLocked() // The endpoint may have already left the multicast group. - if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + if err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) } + + // Leave groups from the perspective of MLD so that routers know that + // we are no longer interested in the group. + e.mld.softLeaveAll() } // stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. @@ -341,7 +357,7 @@ func (e *endpoint) disableLocked() { // Precondition: e.mu must be write locked. func (e *endpoint) stopDADForPermanentAddressesLocked() { // Stop DAD for all the tentative unicast addresses. - e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool { + e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { if addressEndpoint.GetKind() != stack.PermanentTentative { return true } @@ -376,7 +392,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } -func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { length := uint16(pkt.Size()) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -384,8 +400,8 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s NextHeader: uint8(params.Protocol), HopLimit: params.TTL, TrafficClass: params.TOS, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + SrcAddr: srcAddr, + DstAddr: dstAddr, }) pkt.NetworkProtocolNumber = ProtocolNumber } @@ -440,7 +456,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r, pkt, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params) // iptables filtering. All packets that reach here are locally // generated. @@ -529,7 +545,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { - e.addIPHeader(r, pb, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params) networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size())) if err != nil { @@ -737,8 +753,11 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { return } - addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) - if addressEndpoint == nil { + // The destination address should be an address we own or a group we joined + // for us to receive the packet. Otherwise, attempt to forward the packet. + if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() + } else if !e.IsInGroup(dstAddr) { if !e.protocol.Forwarding() { stats.IP.InvalidDestinationAddressesReceived.Increment() return @@ -747,7 +766,6 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { _ = e.forwardPacket(pkt) return } - addressEndpoint.DecRef() // vv consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). @@ -1090,9 +1108,16 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // // Which when taken together indicate that an unknown protocol should // be treated as an unrecognized next header value. + // The location of the Next Header field is in a different place in + // the initial IPv6 header than it is in the extension headers so + // treat it specially. + prevHdrIDOffset := uint32(header.IPv6NextHeaderOffset) + if previousHeaderStart != 0 { + prevHdrIDOffset = previousHeaderStart + } _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, - pointer: it.ParseOffset(), + pointer: prevHdrIDOffset, }, pkt) default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) @@ -1100,12 +1125,11 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { } default: - _ = e.protocol.returnError(&icmpReasonParameterProblem{ - code: header.ICMPv6UnknownHeader, - pointer: it.ParseOffset(), - }, pkt) - stats.UnknownProtocolRcvdPackets.Increment() - return + // Since the iterator returns IPv6RawPayloadHeader for unknown Extension + // Header IDs this should never happen unless we missed a supported type + // here. + panic(fmt.Sprintf("unrecognized type from it.Next() = %T", extHdr)) + } } } @@ -1154,8 +1178,10 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre } snmc := header.SolicitedNodeAddr(addr.Address) - if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil { - return nil, err + if err := e.joinGroupLocked(snmc); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err)) } addressEndpoint.SetKind(stack.PermanentTentative) @@ -1211,7 +1237,8 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn } snmc := header.SolicitedNodeAddr(addr.Address) - if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress { + // The endpoint may have already left the multicast group. + if err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress { return err } @@ -1234,7 +1261,7 @@ func (e *endpoint) hasPermanentAddressRLocked(addr tcpip.Address) bool { // // Precondition: e.mu must be read or write locked. func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpoint { - return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr) + return e.mu.addressableEndpointState.GetAddress(localAddr) } // MainAddress implements stack.AddressableEndpoint. @@ -1285,7 +1312,7 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address // Create a candidate set of available addresses we can potentially use as a // source address. var cs []addrCandidate - e.mu.addressableEndpointState.ReadOnly().ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) { + e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) { // If r is not valid for outgoing connections, it is not a valid endpoint. if !addressEndpoint.IsAssigned(allowExpired) { return @@ -1376,28 +1403,43 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { } // JoinGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + return e.joinGroupLocked(addr) +} + +// joinGroupLocked is like JoinGroup but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { if !header.IsV6MulticastAddress(addr) { - return false, tcpip.ErrBadAddress + return tcpip.ErrBadAddress } - e.mu.Lock() - defer e.mu.Unlock() - return e.mu.addressableEndpointState.JoinGroup(addr) + e.mld.joinGroup(addr) + return nil } // LeaveGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.LeaveGroup(addr) + return e.leaveGroupLocked(addr) +} + +// leaveGroupLocked is like LeaveGroup but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { + return e.mld.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.mu.addressableEndpointState.IsInGroup(addr) + return e.mld.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) @@ -1405,7 +1447,8 @@ var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) type protocol struct { - stack *stack.Stack + stack *stack.Stack + options Options mu struct { sync.RWMutex @@ -1429,26 +1472,6 @@ type protocol struct { forwarding uint32 fragmentation *fragmentation.Fragmentation - - // ndpDisp is the NDP event dispatcher that is used to send the netstack - // integrator NDP related events. - ndpDisp NDPDispatcher - - // ndpConfigs is the default NDP configurations used by an IPv6 endpoint. - ndpConfigs NDPConfigurations - - // opaqueIIDOpts hold the options for generating opaque interface identifiers - // (IIDs) as outlined by RFC 7217. - opaqueIIDOpts OpaqueInterfaceIdentifierOptions - - // tempIIDSeed is used to seed the initial temporary interface identifier - // history value used to generate IIDs for temporary SLAAC addresses. - tempIIDSeed []byte - - // autoGenIPv6LinkLocal determines whether or not the stack attempts to - // auto-generate an IPv6 link-local address for newly enabled non-loopback - // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. - autoGenIPv6LinkLocal bool } // Number returns the ipv6 protocol number. @@ -1484,13 +1507,14 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L e.mu.addressableEndpointState.Init(e) e.mu.ndp = ndpState{ ep: e, - configs: p.ndpConfigs, + configs: p.options.NDPConfigs, dad: make(map[tcpip.Address]dadState), defaultRouters: make(map[tcpip.Address]defaultRouterState), onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), } e.mu.ndp.initializeTempAddrState() + e.mld.init(e, p.options.MLD) p.mu.Lock() defer p.mu.Unlock() @@ -1613,17 +1637,17 @@ type Options struct { // NDPConfigs is the default NDP configurations used by interfaces. NDPConfigs NDPConfigurations - // AutoGenIPv6LinkLocal determines whether or not the stack attempts to - // auto-generate an IPv6 link-local address for newly enabled non-loopback + // AutoGenLinkLocal determines whether or not the stack attempts to + // auto-generate a link-local address for newly enabled non-loopback // NICs. // // Note, setting this to true does not mean that a link-local address is // assigned right away, or at all. If Duplicate Address Detection is enabled, // an address is only assigned if it successfully resolves. If it fails, no - // further attempts are made to auto-generate an IPv6 link-local adddress. + // further attempts are made to auto-generate a link-local adddress. // // The generated link-local address follows RFC 4291 Appendix A guidelines. - AutoGenIPv6LinkLocal bool + AutoGenLinkLocal bool // NDPDisp is the NDP event dispatcher that an integrator can provide to // receive NDP related events. @@ -1647,6 +1671,9 @@ type Options struct { // seed that is too small would reduce randomness and increase predictability, // defeating the purpose of temporary SLAAC addresses. TempIIDSeed []byte + + // MLD holds options for MLD. + MLD MLDOptions } // NewProtocolWithOptions returns an IPv6 network protocol. @@ -1658,15 +1685,11 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { return func(s *stack.Stack) stack.NetworkProtocol { p := &protocol{ - stack: s, + stack: s, + options: opts, + ids: ids, hashIV: hashIV, - - ndpDisp: opts.NDPDisp, - ndpConfigs: opts.NDPConfigs, - opaqueIIDOpts: opts.OpaqueIIDOpts, - tempIIDSeed: opts.TempIIDSeed, - autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, } p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) p.mu.eps = make(map[*endpoint]struct{}) diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index a671d4bac..1c01f17ab 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -51,6 +51,7 @@ const ( fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier) destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier) noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier) + unknownHdrID = uint8(header.IPv6UnknownExtHdrIdentifier) extraHeaderReserve = 50 ) @@ -79,7 +80,7 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst Data: hdr.View().ToVectorisedView(), })) - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived if got := stats.NeighborAdvert.Value(); got != want { t.Fatalf("got NeighborAdvert = %d, want = %d", got, want) @@ -573,6 +574,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { expectICMP: false, }, { + name: "unknown next header (first)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, 63, 4, 1, 2, 3, 4, + }, unknownHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6NextHeaderOffset, + }, + { + name: "unknown next header (not first)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + unknownHdrID, 0, + 63, 4, 1, 2, 3, 4, + }, hopByHopExtHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6FixedHeaderSize, + }, + { name: "destination with unknown option skippable action", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ @@ -755,11 +783,6 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { pointer: header.IPv6FixedHeaderSize, }, { - name: "No next header", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID }, - shouldAccept: false, - }, - { name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ @@ -873,7 +896,13 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { Length: uint16(udpLength), }) copy(u.Payload(), udpPayload) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength)) + + dstAddr := tcpip.Address(addr2) + if test.multicast { + dstAddr = header.IPv6AllNodesMulticastAddress + } + + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, dstAddr, uint16(udpLength)) sum = header.Checksum(udpPayload, sum) u.SetChecksum(^u.CalculateChecksum(sum)) @@ -884,10 +913,6 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Serialize IPv6 fixed header. payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - dstAddr := tcpip.Address(addr2) - if test.multicast { - dstAddr = header.IPv6AllNodesMulticastAddress - } ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), NextHeader: ipv6NextHdr, @@ -982,9 +1007,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { udpPayload2Length = 128 // Used to test cases where the fragment blocks are not a multiple of // the fragment block size of 8 (RFC 8200 section 4.5). - udpPayload3Length = 127 - udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize - fragmentExtHdrLen = 8 + udpPayload3Length = 127 + udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize + udpMaximumSizeMinus15 = header.UDPMaximumSize - 15 + fragmentExtHdrLen = 8 // Note, not all routing extension headers will be 8 bytes but this test // uses 8 byte routing extension headers for most sub tests. routingExtHdrLen = 8 @@ -1328,14 +1354,14 @@ func TestReceiveIPv6Fragments(t *testing.T) { dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+65520, + fragmentExtHdrLen+udpMaximumSizeMinus15, []buffer.View{ // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - ipv6Payload4Addr1ToAddr2[:65520], + ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], }, ), }, @@ -1344,14 +1370,17 @@ func TestReceiveIPv6Fragments(t *testing.T) { dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-65520, + fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15, []buffer.View{ // Fragment extension header. // - // Fragment offset = 8190, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 255, 240, 0, 0, 0, 1}), + // Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, + udpMaximumSizeMinus15 >> 8, + udpMaximumSizeMinus15 & 0xff, + 0, 0, 0, 1}), - ipv6Payload4Addr1ToAddr2[65520:], + ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], }, ), }, @@ -1359,6 +1388,47 @@ func TestReceiveIPv6Fragments(t *testing.T) { expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, }, { + name: "Two fragments with MF flag reassembled into a maximum UDP packet", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+udpMaximumSizeMinus15, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, + udpMaximumSizeMinus15 >> 8, + (udpMaximumSizeMinus15 & 0xff) + 1, + 0, 0, 0, 1}), + + ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], + }, + ), + }, + }, + expectedPayloads: nil, + }, + { name: "Two fragments with per-fragment routing header with zero segments left", fragments: []fragmentData{ { @@ -2439,7 +2509,7 @@ func TestWriteStats(t *testing.T) { test.setup(t, rt.Stack()) - nWritten, _ := writer.writePackets(&rt, pkts) + nWritten, _ := writer.writePackets(rt, pkts) if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) @@ -2456,7 +2526,7 @@ func TestWriteStats(t *testing.T) { } } -func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { +func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go new file mode 100644 index 000000000..4c06b3f0c --- /dev/null +++ b/pkg/tcpip/network/ipv6/mld.go @@ -0,0 +1,164 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + // UnsolicitedReportIntervalMax is the maximum delay between sending + // unsolicited MLD reports. + // + // Obtained from RFC 2710 Section 7.10. + UnsolicitedReportIntervalMax = 10 * time.Second +) + +// MLDOptions holds options for MLD. +type MLDOptions struct { + // Enabled indicates whether MLD will be performed. + // + // When enabled, MLD may transmit MLD report and done messages when + // joining and leaving multicast groups respectively, and handle incoming + // MLD packets. + Enabled bool +} + +var _ ip.MulticastGroupProtocol = (*mldState)(nil) + +// mldState is the per-interface MLD state. +// +// mldState.init MUST be called to initialize the MLD state. +type mldState struct { + // The IPv6 endpoint this mldState is for. + ep *endpoint + + genericMulticastProtocol ip.GenericMulticastProtocolState +} + +// SendReport implements ip.MulticastGroupProtocol. +func (mld *mldState) SendReport(groupAddress tcpip.Address) *tcpip.Error { + return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport) +} + +// SendLeave implements ip.MulticastGroupProtocol. +func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { + return mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) +} + +// init sets up an mldState struct, and is required to be called before using +// a new mldState. +func (mld *mldState) init(ep *endpoint, opts MLDOptions) { + mld.ep = ep + mld.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{ + Enabled: opts.Enabled, + Rand: ep.protocol.stack.Rand(), + Clock: ep.protocol.stack.Clock(), + Protocol: mld, + MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax, + AllNodesAddress: header.IPv6AllNodesMulticastAddress, + }) +} + +func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) { + mld.genericMulticastProtocol.HandleQuery(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay()) +} + +func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) { + mld.genericMulticastProtocol.HandleReport(mldHdr.MulticastAddress()) +} + +// joinGroup handles joining a new group and sending and scheduling the required +// messages. +// +// If the group is already joined, returns tcpip.ErrDuplicateAddress. +func (mld *mldState) joinGroup(groupAddress tcpip.Address) { + mld.genericMulticastProtocol.JoinGroup(groupAddress, !mld.ep.Enabled() /* dontInitialize */) +} + +// isInGroup returns true if the specified group has been joined locally. +func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool { + return mld.genericMulticastProtocol.IsLocallyJoined(groupAddress) +} + +// leaveGroup handles removing the group from the membership map, cancels any +// delay timers associated with that group, and sends the Done message, if +// required. +func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { + // LeaveGroup returns false only if the group was not joined. + if mld.genericMulticastProtocol.LeaveGroup(groupAddress) { + return nil + } + + return tcpip.ErrBadLocalAddress +} + +// softLeaveAll leaves all groups from the perspective of MLD, but remains +// joined locally. +func (mld *mldState) softLeaveAll() { + mld.genericMulticastProtocol.MakeAllNonMember() +} + +// initializeAll attemps to initialize the MLD state for each group that has +// been joined locally. +func (mld *mldState) initializeAll() { + mld.genericMulticastProtocol.InitializeGroups() +} + +func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) *tcpip.Error { + sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent + var mldStat *tcpip.StatCounter + switch mldType { + case header.ICMPv6MulticastListenerReport: + mldStat = sentStats.MulticastListenerReport + case header.ICMPv6MulticastListenerDone: + mldStat = sentStats.MulticastListenerDone + default: + panic(fmt.Sprintf("unrecognized mld type = %d", mldType)) + } + + icmp := header.ICMPv6(buffer.NewView(header.ICMPv6HeaderSize + header.MLDMinimumSize)) + icmp.SetType(mldType) + header.MLD(icmp.MessageBody()).SetMulticastAddress(groupAddress) + // TODO(gvisor.dev/issue/4888): We should not use the unspecified address, + // rather we should select an appropriate local address. + localAddress := header.IPv6Any + icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{})) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()), + Data: buffer.View(icmp).ToVectorisedView(), + }) + + mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.MLDHopLimit, + }) + // TODO(b/162198658): set the ROUTER_ALERT option when sending Host + // Membership Reports. + if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + sentStats.Dropped.Increment() + return err + } + mldStat.Increment() + return nil +} diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go new file mode 100644 index 000000000..5677bdd54 --- /dev/null +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -0,0 +1,90 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" +) + +func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + MLD: ipv6.MLDOptions{ + Enabled: true, + }, + })}, + }) + e := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + // The stack will join an address's solicited node multicast address when + // an address is added. An MLD report message should be sent for the + // solicited-node group. + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, addr1, err) + } + { + p, ok := e.Read() + if !ok { + t.Fatal("expected a report message to be sent") + } + snmc := header.SolicitedNodeAddr(addr1) + checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())), + checker.DstAddr(snmc), + // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. + checker.TTL(1), + checker.MLD(header.ICMPv6MulticastListenerReport, header.MLDMinimumSize, + checker.MLDMaxRespDelay(0), + checker.MLDMulticastAddress(snmc), + ), + ) + } + + // The stack will leave an address's solicited node multicast address when + // an address is removed. An MLD done message should be sent for the + // solicited-node group. + if err := s.RemoveAddress(nicID, addr1); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) + } + { + p, ok := e.Read() + if !ok { + t.Fatal("expected a done message to be sent") + } + snmc := header.SolicitedNodeAddr(addr1) + checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(1), + checker.MLD(header.ICMPv6MulticastListenerDone, header.MLDMinimumSize, + checker.MLDMaxRespDelay(0), + checker.MLDMulticastAddress(snmc), + ), + ) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 40da011f8..8cb7d4dab 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -471,17 +471,8 @@ type ndpState struct { // The default routers discovered through Router Advertisements. defaultRouters map[tcpip.Address]defaultRouterState - rtrSolicit struct { - // The timer used to send the next router solicitation message. - timer tcpip.Timer - - // Used to let the Router Solicitation timer know that it has been stopped. - // - // Must only be read from or written to while protected by the lock of - // the IPv6 endpoint this ndpState is associated with. MUST be set when the - // timer is set. - done *bool - } + // The job used to send the next router solicitation message. + rtrSolicitJob *tcpip.Job // The on-link prefixes discovered through Router Advertisements' Prefix // Information option. @@ -507,7 +498,7 @@ type ndpState struct { // to the DAD goroutine that DAD should stop. type dadState struct { // The DAD timer to send the next NS message, or resolve the address. - timer tcpip.Timer + job *tcpip.Job // Used to let the DAD timer know that it has been stopped. // @@ -648,96 +639,70 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE // Consider DAD to have resolved even if no DAD messages were actually // transmitted. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil) } return nil } - var done bool - var timer tcpip.Timer - // We initially start a timer to fire immediately because some of the DAD work - // cannot be done while holding the IPv6 endpoint's lock. This is effectively - // the same as starting a goroutine but we use a timer that fires immediately - // so we can reset it for the next DAD iteration. - timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() { - ndp.ep.mu.Lock() - defer ndp.ep.mu.Unlock() - - if done { - // If we reach this point, it means that the DAD timer fired after - // another goroutine already obtained the IPv6 endpoint lock and stopped - // DAD before this function obtained the NIC lock. Simply return here and - // do nothing further. - return - } - - if addressEndpoint.GetKind() != stack.PermanentTentative { - // The endpoint should still be marked as tentative since we are still - // performing DAD on it. - panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID())) - } + state := dadState{ + job: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { + state, ok := ndp.dad[addr] + if !ok { + panic(fmt.Sprintf("ndpdad: DAD timer fired but missing state for %s on NIC(%d)", addr, ndp.ep.nic.ID())) + } - dadDone := remaining == 0 - - var err *tcpip.Error - if !dadDone { - // Use the unspecified address as the source address when performing DAD. - addressEndpoint := ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint) - - // Do not hold the lock when sending packets which may be a long running - // task or may block link address resolution. We know this is safe - // because immediately after obtaining the lock again, we check if DAD - // has been stopped before doing any work with the IPv6 endpoint. Note, - // DAD would be stopped if the IPv6 endpoint was disabled or closed, or if - // the address was removed. - ndp.ep.mu.Unlock() - err = ndp.sendDADPacket(addr, addressEndpoint) - ndp.ep.mu.Lock() - addressEndpoint.DecRef() - } + if addressEndpoint.GetKind() != stack.PermanentTentative { + // The endpoint should still be marked as tentative since we are still + // performing DAD on it. + panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID())) + } - if done { - // If we reach this point, it means that DAD was stopped after we released - // the IPv6 endpoint's read lock and before we obtained the write lock. - return - } + dadDone := remaining == 0 - if dadDone { - // DAD has resolved. - addressEndpoint.SetKind(stack.Permanent) - } else if err == nil { - // DAD is not done and we had no errors when sending the last NDP NS, - // schedule the next DAD timer. - remaining-- - timer.Reset(ndp.configs.RetransmitTimer) - return - } + var err *tcpip.Error + if !dadDone { + err = ndp.sendDADPacket(addr, addressEndpoint) + } - // At this point we know that either DAD is done or we hit an error sending - // the last NDP NS. Either way, clean up addr's DAD state and let the - // integrator know DAD has completed. - delete(ndp.dad, addr) + if dadDone { + // DAD has resolved. + addressEndpoint.SetKind(stack.Permanent) + } else if err == nil { + // DAD is not done and we had no errors when sending the last NDP NS, + // schedule the next DAD timer. + remaining-- + state.job.Schedule(ndp.configs.RetransmitTimer) + return + } - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err) - } + // At this point we know that either DAD is done or we hit an error + // sending the last NDP NS. Either way, clean up addr's DAD state and let + // the integrator know DAD has completed. + delete(ndp.dad, addr) - // If DAD resolved for a stable SLAAC address, attempt generation of a - // temporary SLAAC address. - if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac { - // Reset the generation attempts counter as we are starting the generation - // of a new address for the SLAAC prefix. - ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) - } - }) + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err) + } - ndp.dad[addr] = dadState{ - timer: timer, - done: &done, + // If DAD resolved for a stable SLAAC address, attempt generation of a + // temporary SLAAC address. + if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac { + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) + } + }), } + // We initially start a timer to fire immediately because some of the DAD work + // cannot be done while holding the IPv6 endpoint's lock. This is effectively + // the same as starting a goroutine but we use a timer that fires immediately + // so we can reset it for the next DAD iteration. + state.job.Schedule(0) + ndp.dad[addr] = state + return nil } @@ -745,55 +710,31 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE // addr. // // addr must be a tentative IPv6 address on ndp's IPv6 endpoint. -// -// The IPv6 endpoint that ndp belongs to MUST NOT be locked. func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err - } - defer r.Release() - - // Route should resolve immediately since snmc is a multicast address so a - // remote link address can be calculated without a resolution process. - if c, err := r.Resolve(nil); err != nil { - // Do not consider the NIC being unknown or disabled as a fatal error. - // Since this method is required to be called when the IPv6 endpoint is not - // locked, the NIC could have been disabled or removed by another goroutine. - if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState { - return err - } - - panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err)) - } else if c != nil { - panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID())) - } - - icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) - icmpData.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmpData.NDPPayload()) + icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(addr) - icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, snmc, buffer.VectorisedView{})) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: buffer.View(icmpData).ToVectorisedView(), + ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()), + Data: buffer.View(icmp).ToVectorisedView(), }) - sent := r.Stats().ICMP.V6PacketsSent - if err := r.WritePacket(nil, - stack.NetworkHeaderParams{ - Protocol: header.ICMPv6ProtocolNumber, - TTL: header.NDPHopLimit, - }, pkt, - ); err != nil { + sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent + ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }) + + if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil { sent.Dropped.Increment() return err } sent.NeighborSolicit.Increment() - return nil } @@ -812,18 +753,11 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { return } - if dad.timer != nil { - dad.timer.Stop() - dad.timer = nil - - *dad.done = true - dad.done = nil - } - + dad.job.Cancel() delete(ndp.dad, addr) // Let the integrator know DAD did not resolve. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil) } } @@ -846,7 +780,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we // only inform the dispatcher on configuration changes. We do nothing else // with the information. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { var configuration DHCPv6ConfigurationFromNDPRA switch { case ra.ManagedAddrConfFlag(): @@ -903,20 +837,20 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() { switch opt := opt.(type) { case header.NDPRecursiveDNSServer: - if ndp.ep.protocol.ndpDisp == nil { + if ndp.ep.protocol.options.NDPDisp == nil { continue } addrs, _ := opt.Addresses() - ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime()) + ndp.ep.protocol.options.NDPDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime()) case header.NDPDNSSearchList: - if ndp.ep.protocol.ndpDisp == nil { + if ndp.ep.protocol.options.NDPDisp == nil { continue } domainNames, _ := opt.DomainNames() - ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime()) + ndp.ep.protocol.options.NDPDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime()) case header.NDPPrefixInformation: prefix := opt.Subnet() @@ -964,7 +898,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { delete(ndp.defaultRouters, ip) // Let the integrator know a discovered default router is invalidated. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip) } } @@ -976,7 +910,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { - ndpDisp := ndp.ep.protocol.ndpDisp + ndpDisp := ndp.ep.protocol.options.NDPDisp if ndpDisp == nil { return } @@ -1006,7 +940,7 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) { - ndpDisp := ndp.ep.protocol.ndpDisp + ndpDisp := ndp.ep.protocol.options.NDPDisp if ndpDisp == nil { return } @@ -1047,7 +981,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { delete(ndp.onLinkPrefixes, prefix) // Let the integrator know a discovered on-link prefix is invalidated. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix) } } @@ -1225,7 +1159,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, configType stack.AddressConfigType, deprecated bool) stack.AddressEndpoint { // Inform the integrator that we have a new SLAAC address. - ndpDisp := ndp.ep.protocol.ndpDisp + ndpDisp := ndp.ep.protocol.options.NDPDisp if ndpDisp == nil { return nil } @@ -1272,7 +1206,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt } dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures - if oIID := ndp.ep.protocol.opaqueIIDOpts; oIID.NICNameFromID != nil { + if oIID := ndp.ep.protocol.options.OpaqueIIDOpts; oIID.NICNameFromID != nil { addrBytes = header.AppendOpaqueInterfaceIdentifier( addrBytes[:header.IIDOffsetInIPv6Address], prefix, @@ -1676,7 +1610,7 @@ func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint } addressEndpoint.SetDeprecated(true) - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix()) } } @@ -1701,7 +1635,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefi // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) { - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr) } @@ -1761,7 +1695,7 @@ func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLA // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) { - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr) } @@ -1859,7 +1793,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) startSolicitingRouters() { - if ndp.rtrSolicit.timer != nil { + if ndp.rtrSolicitJob != nil { // We are already soliciting routers. return } @@ -1876,56 +1810,14 @@ func (ndp *ndpState) startSolicitingRouters() { delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) } - var done bool - ndp.rtrSolicit.done = &done - ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() { - ndp.ep.mu.Lock() - if done { - // If we reach this point, it means that the RS timer fired after another - // goroutine already obtained the IPv6 endpoint lock and stopped - // solicitations. Simply return here and do nothing further. - ndp.ep.mu.Unlock() - return - } - + ndp.rtrSolicitJob = ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { // As per RFC 4861 section 4.1, the source of the RS is an address assigned // to the sending interface, or the unspecified address if no address is // assigned to the sending interface. - addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false) - if addressEndpoint == nil { - // Incase this ends up creating a new temporary address, we need to hold - // onto the endpoint until a route is obtained. If we decrement the - // reference count before obtaing a route, the address's resources would - // be released and attempting to obtain a route after would fail. Once a - // route is obtainted, it is safe to decrement the reference count since - // obtaining a route increments the address's reference count. - addressEndpoint = ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint) - } - ndp.ep.mu.Unlock() - - localAddr := addressEndpoint.AddressWithPrefix().Address - r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */) - addressEndpoint.DecRef() - if err != nil { - return - } - defer r.Release() - - // Route should resolve immediately since - // header.IPv6AllRoutersMulticastAddress is a multicast address so a - // remote link address can be calculated without a resolution process. - if c, err := r.Resolve(nil); err != nil { - // Do not consider the NIC being unknown or disabled as a fatal error. - // Since this method is required to be called when the IPv6 endpoint is - // not locked, the IPv6 endpoint could have been disabled or removed by - // another goroutine. - if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState { - return - } - - panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err)) - } else if c != nil { - panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID())) + localAddr := header.IPv6Any + if addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil { + localAddr = addressEndpoint.AddressWithPrefix().Address + addressEndpoint.DecRef() } // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source @@ -1936,30 +1828,31 @@ func (ndp *ndpState) startSolicitingRouters() { // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by // LinkEndpoint.LinkAddress) before reaching this point. var optsSerializer header.NDPOptionsSerializer - if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) { + linkAddress := ndp.ep.nic.LinkAddress() + if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(linkAddress) { optsSerializer = header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress), + header.NDPSourceLinkLayerAddressOption(linkAddress), } } payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) icmpData := header.ICMPv6(buffer.NewView(payloadSize)) icmpData.SetType(header.ICMPv6RouterSolicit) - rs := header.NDPRouterSolicit(icmpData.NDPPayload()) + rs := header.NDPRouterSolicit(icmpData.MessageBody()) rs.Options().Serialize(optsSerializer) - icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, localAddr, header.IPv6AllRoutersMulticastAddress, buffer.VectorisedView{})) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), + ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()), Data: buffer.View(icmpData).ToVectorisedView(), }) - sent := r.Stats().ICMP.V6PacketsSent - if err := r.WritePacket(nil, - stack.NetworkHeaderParams{ - Protocol: header.ICMPv6ProtocolNumber, - TTL: header.NDPHopLimit, - }, pkt, - ); err != nil { + sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent + ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }) + + if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { sent.Dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err) // Don't send any more messages if we had an error. @@ -1969,21 +1862,12 @@ func (ndp *ndpState) startSolicitingRouters() { remaining-- } - ndp.ep.mu.Lock() - if done || remaining == 0 { - ndp.rtrSolicit.timer = nil - ndp.rtrSolicit.done = nil - } else if ndp.rtrSolicit.timer != nil { - // Note, we need to explicitly check to make sure that - // the timer field is not nil because if it was nil but - // we still reached this point, then we know the IPv6 endpoint - // was requested to stop soliciting routers so we don't - // need to send the next Router Solicitation message. - ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval) + if remaining != 0 { + ndp.rtrSolicitJob.Schedule(ndp.configs.RtrSolicitationInterval) } - ndp.ep.mu.Unlock() }) + ndp.rtrSolicitJob.Schedule(delay) } // stopSolicitingRouters stops soliciting routers. If routers are not currently @@ -1991,21 +1875,19 @@ func (ndp *ndpState) startSolicitingRouters() { // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) stopSolicitingRouters() { - if ndp.rtrSolicit.timer == nil { + if ndp.rtrSolicitJob == nil { // Nothing to do. return } - *ndp.rtrSolicit.done = true - ndp.rtrSolicit.timer.Stop() - ndp.rtrSolicit.timer = nil - ndp.rtrSolicit.done = nil + ndp.rtrSolicitJob.Cancel() + ndp.rtrSolicitJob = nil } // initializeTempAddrState initializes state related to temporary SLAAC // addresses. func (ndp *ndpState) initializeTempAddrState() { - header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID()) + header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID()) if MaxDesyncFactor != 0 { ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 37e8b1083..95c626bb8 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -205,7 +205,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(lladdr0) opts := ns.Options() copy(opts, test.optsBuf) @@ -220,7 +220,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { DstAddr: lladdr0, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { @@ -311,7 +311,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(lladdr0) opts := ns.Options() copy(opts, test.optsBuf) @@ -326,16 +326,16 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi DstAddr: lladdr0, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) neighbors, err := s.Neighbors(nicID) if err != nil { @@ -591,7 +591,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(nicAddr) opts := ns.Options() opts.Serialize(test.nsOpts) @@ -606,7 +606,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { DstAddr: test.nsDst, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { @@ -650,8 +650,8 @@ func TestNeighorSolicitationResponse(t *testing.T) { if p.Route.RemoteAddress != respNSDst { t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst) } - if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(respNSDst); got != want { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), @@ -672,7 +672,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na := header.NDPNeighborAdvert(pkt.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(test.nsSrc) @@ -706,8 +706,8 @@ func TestNeighorSolicitationResponse(t *testing.T) { if p.Route.RemoteAddress != test.naDst { t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) } - if p.Route.RemoteLinkAddress != test.naDstLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != test.naDstLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, test.naDstLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), @@ -777,7 +777,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) pkt.SetType(header.ICMPv6NeighborAdvert) - ns := header.NDPNeighborAdvert(pkt.NDPPayload()) + ns := header.NDPNeighborAdvert(pkt.MessageBody()) ns.SetTargetAddress(lladdr1) opts := ns.Options() copy(opts, test.optsBuf) @@ -792,7 +792,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { DstAddr: lladdr0, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { @@ -890,7 +890,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) pkt.SetType(header.ICMPv6NeighborAdvert) - ns := header.NDPNeighborAdvert(pkt.NDPPayload()) + ns := header.NDPNeighborAdvert(pkt.MessageBody()) ns.SetTargetAddress(lladdr1) opts := ns.Options() copy(opts, test.optsBuf) @@ -905,16 +905,16 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test DstAddr: lladdr0, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) neighbors, err := s.Neighbors(nicID) if err != nil { @@ -1122,7 +1122,7 @@ func TestNDPValidation(t *testing.T) { s.SetForwarding(ProtocolNumber, true) } - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid routerOnly := stats.RouterOnlyPacketsDroppedByHost typStat := typ.statCounter(stats) @@ -1346,7 +1346,7 @@ func TestRouterAdvertValidation(t *testing.T) { pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) pkt.SetCode(test.code) - copy(pkt.NDPPayload(), test.ndpPayload) + copy(pkt.MessageBody(), test.ndpPayload) payloadLength := hdr.UsedLength() pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -1358,7 +1358,7 @@ func TestRouterAdvertValidation(t *testing.T) { DstAddr: header.IPv6AllNodesMulticastAddress, }) - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid rxRA := stats.RouterAdvert diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go new file mode 100644 index 000000000..95fb67986 --- /dev/null +++ b/pkg/tcpip/network/multicast_group_test.go @@ -0,0 +1,1069 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip_test + +import ( + "fmt" + "strings" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + + ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03") + ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04") + ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05") + ipv6MulticastAddr1 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + ipv6MulticastAddr2 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04") + ipv6MulticastAddr3 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05") + + igmpMembershipQuery = uint8(header.IGMPMembershipQuery) + igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport) + igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport) + igmpLeaveGroup = uint8(header.IGMPLeaveGroup) + mldQuery = uint8(header.ICMPv6MulticastListenerQuery) + mldReport = uint8(header.ICMPv6MulticastListenerReport) + mldDone = uint8(header.ICMPv6MulticastListenerDone) +) + +var ( + // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the + // NIC will wait before sending an unsolicited report after joining a + // multicast group, in deciseconds. + unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 { + const decisecond = time.Second / 10 + if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 { + panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax)) + } + return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond) + }() +) + +// validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet +// sent to the provided address with the passed fields set. +func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) { + t.Helper() + + payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) + checker.IPv6(t, payload, + checker.DstAddr(remoteAddress), + // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. + checker.TTL(1), + checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize, + checker.MLDMaxRespDelay(time.Duration(maxRespTime)*time.Millisecond), + checker.MLDMulticastAddress(groupAddress), + ), + ) +} + +// validateIGMPPacket checks that a passed PacketInfo is an IPv4 IGMP packet +// sent to the provided address with the passed fields set. +func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) { + t.Helper() + + payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) + checker.IPv4(t, payload, + checker.DstAddr(remoteAddress), + // TTL for an IGMP message must be 1 as per RFC 2236 section 2. + checker.TTL(1), + checker.IPv4RouterAlert(), + checker.IGMP( + checker.IGMPType(header.IGMPType(igmpType)), + checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)), + checker.IGMPGroupAddress(groupAddress), + ), + ) +} + +func createStack(t *testing.T, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { + t.Helper() + + // Create an endpoint of queue size 2, since no more than 2 packets are ever + // queued in the tests in this file. + e := channel.New(2, 1280, linkAddr) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocolWithOptions(ipv4.Options{ + IGMP: ipv4.IGMPOptions{ + Enabled: mgpEnabled, + }, + }), + ipv6.NewProtocolWithOptions(ipv6.Options{ + MLD: ipv6.MLDOptions{ + Enabled: mgpEnabled, + }, + }), + }, + Clock: clock, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + return e, s, clock +} + +// createAndInjectIGMPPacket creates and injects an IGMP packet with the +// specified fields. +// +// Note, the router alert option is not included in this packet. +// +// TODO(b/162198658): set the router alert option. +func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address) { + buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize) + + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(len(buf)), + TTL: header.IGMPTTL, + Protocol: uint8(header.IGMPProtocolNumber), + SrcAddr: header.IPv4Any, + DstAddr: header.IPv4AllSystems, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + igmp := header.IGMP(buf[header.IPv4MinimumSize:]) + igmp.SetType(header.IGMPType(igmpType)) + igmp.SetMaxRespTime(maxRespTime) + igmp.SetGroupAddress(groupAddress) + igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) + + e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) +} + +// createAndInjectMLDPacket creates and injects an MLD packet with the +// specified fields. +// +// Note, the router alert option is not included in this packet. +// +// TODO(b/162198658): set the router alert option. +func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address) { + icmpSize := header.ICMPv6HeaderSize + header.MLDMinimumSize + buf := buffer.NewView(header.IPv6MinimumSize + icmpSize) + + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(icmpSize), + HopLimit: header.MLDHopLimit, + NextHeader: uint8(header.ICMPv6ProtocolNumber), + SrcAddr: header.IPv4Any, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + icmp := header.ICMPv6(buf[header.IPv6MinimumSize:]) + icmp.SetType(header.ICMPv6Type(mldType)) + mld := header.MLD(icmp.MessageBody()) + mld.SetMaximumResponseDelay(uint16(maxRespDelay)) + mld.SetMulticastAddress(groupAddress) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + + e.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) +} + +// TestMGPDisabled tests that the multicast group protocol is not enabled by +// default. +func TestMGPDisabled(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + receivedQueryStat func(*stack.Stack) *tcpip.StatCounter + rxQuery func(*channel.Endpoint) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.MembershipQuery + }, + rxQuery: func(e *channel.Endpoint) { + createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any) + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery + }, + rxQuery: func(e *channel.Endpoint) { + createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, false) + + // This NIC may join multicast groups when it is enabled but since MGP is + // disabled, no reports should be sent. + sentReportStat := test.sentReportStat(s) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportState.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt) + } + + // Test joining a specific group explicitly and verify that no reports are + // sent. + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportState.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %#v", p.Pkt) + } + + // Inject a general query message. This should only trigger a report to be + // sent if the MGP was enabled. + test.rxQuery(e) + if got := test.receivedQueryStat(s).Value(); got != 1 { + t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt) + } + }) + } +} + +func TestMGPReceiveCounters(t *testing.T) { + tests := []struct { + name string + headerType uint8 + maxRespTime byte + groupAddress tcpip.Address + statCounter func(*stack.Stack) *tcpip.StatCounter + rxMGPkt func(*channel.Endpoint, byte, byte, tcpip.Address) + }{ + { + name: "IGMP Membership Query", + headerType: igmpMembershipQuery, + maxRespTime: unsolicitedIGMPReportIntervalMaxTenthSec, + groupAddress: header.IPv4Any, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.MembershipQuery + }, + rxMGPkt: createAndInjectIGMPPacket, + }, + { + name: "IGMPv1 Membership Report", + headerType: igmpv1MembershipReport, + maxRespTime: 0, + groupAddress: header.IPv4AllSystems, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.V1MembershipReport + }, + rxMGPkt: createAndInjectIGMPPacket, + }, + { + name: "IGMPv2 Membership Report", + headerType: igmpv2MembershipReport, + maxRespTime: 0, + groupAddress: header.IPv4AllSystems, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.V2MembershipReport + }, + rxMGPkt: createAndInjectIGMPPacket, + }, + { + name: "IGMP Leave Group", + headerType: igmpLeaveGroup, + maxRespTime: 0, + groupAddress: header.IPv4AllRoutersGroup, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.LeaveGroup + }, + rxMGPkt: createAndInjectIGMPPacket, + }, + { + name: "MLD Query", + headerType: mldQuery, + maxRespTime: 0, + groupAddress: header.IPv6Any, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery + }, + rxMGPkt: createAndInjectMLDPacket, + }, + { + name: "MLD Report", + headerType: mldReport, + maxRespTime: 0, + groupAddress: header.IPv6Any, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerReport + }, + rxMGPkt: createAndInjectMLDPacket, + }, + { + name: "MLD Done", + headerType: mldDone, + maxRespTime: 0, + groupAddress: header.IPv6Any, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerDone + }, + rxMGPkt: createAndInjectMLDPacket, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, _ := createStack(t, true) + + test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress) + if got := test.statCounter(s).Value(); got != 1 { + t.Fatalf("got %s received = %d, want = 1", test.name, got) + } + }) + } +} + +// TestMGPJoinGroup tests that when explicitly joining a multicast group, the +// stack schedules and sends correct Membership Reports. +func TestMGPJoinGroup(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + maxUnsolicitedResponseDelay time.Duration + sentReportStat func(*stack.Stack) *tcpip.StatCounter + receivedQueryStat func(*stack.Stack) *tcpip.StatCounter + validateReport func(*testing.T, channel.PacketInfo) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.MembershipQuery + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, true) + + // Test joining a specific address explicitly and verify a Report is sent + // immediately. + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + sentReportStat := test.sentReportStat(s) + if got := sentReportStat.Value(); got != 1 { + t.Errorf("got sentReportState.Value() = %d, want = 1", got) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + if t.Failed() { + t.FailNow() + } + + // Verify the second report is sent by the maximum unsolicited response + // interval. + p, ok := e.Read() + if ok { + t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt) + } + clock.Advance(test.maxUnsolicitedResponseDelay) + if got := sentReportStat.Value(); got != 2 { + t.Errorf("got sentReportState.Value() = %d, want = 2", got) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } +} + +// TestMGPLeaveGroup tests that when leaving a previously joined multicast +// group the stack sends a leave/done message. +func TestMGPLeaveGroup(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + sentLeaveStat func(*stack.Stack) *tcpip.StatCounter + validateReport func(*testing.T, channel.PacketInfo) + validateLeave func(*testing.T, channel.PacketInfo) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.LeaveGroup + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + validateLeave: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1) + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) + }, + validateLeave: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, true) + + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + if got := test.sentReportStat(s).Value(); got != 1 { + t.Errorf("got sentReportStat(_).Value() = %d, want = 1", got) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + if t.Failed() { + t.FailNow() + } + + // Leaving the group should trigger an leave/done message to be sent. + if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) + } + if got := test.sentLeaveStat(s).Value(); got != 1 { + t.Fatalf("got sentLeaveStat(_).Value() = %d, want = 1", got) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a leave message to be sent") + } else { + test.validateLeave(t, p) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } +} + +// TestMGPQueryMessages tests that a report is sent in response to query +// messages. +func TestMGPQueryMessages(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + maxUnsolicitedResponseDelay time.Duration + sentReportStat func(*stack.Stack) *tcpip.StatCounter + receivedQueryStat func(*stack.Stack) *tcpip.StatCounter + rxQuery func(*channel.Endpoint, uint8, tcpip.Address) + validateReport func(*testing.T, channel.PacketInfo) + maxRespTimeToDuration func(uint8) time.Duration + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.MembershipQuery + }, + rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { + createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress) + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + maxRespTimeToDuration: header.DecisecondToDuration, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery + }, + rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { + createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress) + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) + }, + maxRespTimeToDuration: func(d uint8) time.Duration { + return time.Duration(d) * time.Millisecond + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + subTests := []struct { + name string + multicastAddr tcpip.Address + expectReport bool + }{ + { + name: "Unspecified", + multicastAddr: tcpip.Address(strings.Repeat("\x00", len(test.multicastAddr))), + expectReport: true, + }, + { + name: "Specified", + multicastAddr: test.multicastAddr, + expectReport: true, + }, + { + name: "Specified other address", + multicastAddr: func() tcpip.Address { + addrBytes := []byte(test.multicastAddr) + addrBytes[len(addrBytes)-1]++ + return tcpip.Address(addrBytes) + }(), + expectReport: false, + }, + } + + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + e, s, clock := createStack(t, true) + + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + sentReportStat := test.sentReportStat(s) + for i := uint64(1); i <= 2; i++ { + sentReportStat := test.sentReportStat(s) + if got := sentReportStat.Value(); got != i { + t.Errorf("(i=%d) got sentReportState.Value() = %d, want = %d", i, got, i) + } + if p, ok := e.Read(); !ok { + t.Fatalf("expected %d-th report message to be sent", i) + } else { + test.validateReport(t, p) + } + clock.Advance(test.maxUnsolicitedResponseDelay) + } + if t.Failed() { + t.FailNow() + } + + // Should not send any more packets until a query. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + + // Receive a query message which should trigger a report to be sent at + // some time before the maximum response time if the report is + // targeted at the host. + const maxRespTime = 100 + test.rxQuery(e, maxRespTime, subTest.multicastAddr) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p.Pkt) + } + + if subTest.expectReport { + clock.Advance(test.maxRespTimeToDuration(maxRespTime)) + if got := sentReportStat.Value(); got != 3 { + t.Errorf("got sentReportState.Value() = %d, want = 3", got) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } + }) + } +} + +// TestMGPQueryMessages tests that no further reports or leave/done messages +// are sent after receiving a report. +func TestMGPReportMessages(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + sentLeaveStat func(*stack.Stack) *tcpip.StatCounter + rxReport func(*channel.Endpoint) + validateReport func(*testing.T, channel.PacketInfo) + maxRespTimeToDuration func(uint8) time.Duration + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.LeaveGroup + }, + rxReport: func(e *channel.Endpoint) { + createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + maxRespTimeToDuration: header.DecisecondToDuration, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + }, + rxReport: func(e *channel.Endpoint) { + createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1) + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) + }, + maxRespTimeToDuration: func(d uint8) time.Duration { + return time.Duration(d) * time.Millisecond + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, true) + + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + sentReportStat := test.sentReportStat(s) + if got := sentReportStat.Value(); got != 1 { + t.Errorf("got sentReportStat.Value() = %d, want = 1", got) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + if t.Failed() { + t.FailNow() + } + + // Receiving a report for a group we joined should cancel any further + // reports. + test.rxReport(e) + clock.Advance(time.Hour) + if got := sentReportStat.Value(); got != 1 { + t.Errorf("got sentReportStat.Value() = %d, want = 1", got) + } + if p, ok := e.Read(); ok { + t.Errorf("sent unexpected packet = %#v", p) + } + if t.Failed() { + t.FailNow() + } + + // Leaving a group after getting a report should not send a leave/done + // message. + if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) + } + clock.Advance(time.Hour) + if got := test.sentLeaveStat(s).Value(); got != 0 { + t.Fatalf("got sentLeaveStat(_).Value() = %d, want = 0", got) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } +} + +func TestMGPWithNICLifecycle(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddrs []tcpip.Address + finalMulticastAddr tcpip.Address + maxUnsolicitedResponseDelay time.Duration + sentReportStat func(*stack.Stack) *tcpip.StatCounter + sentLeaveStat func(*stack.Stack) *tcpip.StatCounter + validateReport func(*testing.T, channel.PacketInfo, tcpip.Address) + validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address) + getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddrs: []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2}, + finalMulticastAddr: ipv4MulticastAddr3, + maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.LeaveGroup + }, + validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr) + }, + validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, addr) + }, + getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { + t.Helper() + + ipv4 := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) + if got := tcpip.TransportProtocolNumber(ipv4.Protocol()); got != header.IGMPProtocolNumber { + t.Fatalf("got ipv4.Protocol() = %d, want = %d", got, header.IGMPProtocolNumber) + } + addr := header.IGMP(ipv4.Payload()).GroupAddress() + s, ok := seen[addr] + if !ok { + t.Fatalf("unexpectedly got a packet for group %s", addr) + } + if s { + t.Fatalf("already saw packet for group %s", addr) + } + seen[addr] = true + return addr + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddrs: []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2}, + finalMulticastAddr: ipv6MulticastAddr3, + maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + }, + validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateMLDPacket(t, p, addr, mldReport, 0, addr) + }, + validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr) + }, + getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { + t.Helper() + + ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) + if got := tcpip.TransportProtocolNumber(ipv6.NextHeader()); got != header.ICMPv6ProtocolNumber { + t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber) + } + icmpv6 := header.ICMPv6(ipv6.Payload()) + if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone { + t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone) + } + addr := header.MLD(icmpv6.MessageBody()).MulticastAddress() + s, ok := seen[addr] + if !ok { + t.Fatalf("unexpectedly got a packet for group %s", addr) + } + if s { + t.Fatalf("already saw packet for group %s", addr) + } + seen[addr] = true + return addr + + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, true) + + sentReportStat := test.sentReportStat(s) + var reportCounter uint64 + for _, a := range test.multicastAddrs { + if err := s.JoinGroup(test.protoNum, nicID, a); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err) + } + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatalf("expected a report message to be sent for %s", a) + } else { + test.validateReport(t, p, a) + } + } + if t.Failed() { + t.FailNow() + } + + // Leave messages should be sent for the joined groups when the NIC is + // disabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + sentLeaveStat := test.sentLeaveStat(s) + leaveCounter := uint64(len(test.multicastAddrs)) + if got := sentLeaveStat.Value(); got != leaveCounter { + t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) + } + { + seen := make(map[tcpip.Address]bool) + for _, a := range test.multicastAddrs { + seen[a] = false + } + + for i, _ := range test.multicastAddrs { + p, ok := e.Read() + if !ok { + t.Fatalf("expected (%d-th) leave message to be sent", i) + } + + test.validateLeave(t, p, test.getAndCheckGroupAddress(t, seen, p)) + } + } + if t.Failed() { + t.FailNow() + } + + // Reports should be sent for the joined groups when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("EnableNIC(%d): %s", nicID, err) + } + reportCounter += uint64(len(test.multicastAddrs)) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + { + seen := make(map[tcpip.Address]bool) + for _, a := range test.multicastAddrs { + seen[a] = false + } + + for i, _ := range test.multicastAddrs { + p, ok := e.Read() + if !ok { + t.Fatalf("expected (%d-th) report message to be sent", i) + } + + test.validateReport(t, p, test.getAndCheckGroupAddress(t, seen, p)) + } + } + if t.Failed() { + t.FailNow() + } + + // Joining/leaving a group while disabled should not send any messages. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + leaveCounter += uint64(len(test.multicastAddrs)) + if got := sentLeaveStat.Value(); got != leaveCounter { + t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) + } + for i, _ := range test.multicastAddrs { + if _, ok := e.Read(); !ok { + t.Fatalf("expected (%d-th) leave message to be sent", i) + } + } + for _, a := range test.multicastAddrs { + if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil { + t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err) + } + if got := sentLeaveStat.Value(); got != leaveCounter { + t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) + } + if p, ok := e.Read(); ok { + t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p.Pkt) + } + } + if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err) + } + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); ok { + t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p.Pkt) + } + + // A report should only be sent for the group we last joined after + // enabling the NIC since the original groups were all left. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("EnableNIC(%d): %s", nicID, err) + } + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p, test.finalMulticastAddr) + } + + clock.Advance(test.maxUnsolicitedResponseDelay) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportState.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p, test.finalMulticastAddr) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } +} diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go index 7cc52985e..5c3363759 100644 --- a/pkg/tcpip/network/testutil/testutil.go +++ b/pkg/tcpip/network/testutil/testutil.go @@ -85,21 +85,6 @@ func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts st return n, nil } -// WriteRawPacket implements LinkEndpoint.WriteRawPacket. -func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - if ep.allowPackets == 0 { - return ep.err - } - ep.allowPackets-- - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - ep.WrittenPackets = append(ep.WrittenPackets, pkt) - - return nil -} - // Attach implements LinkEndpoint.Attach. func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {} |