summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/arp/arp_test.go12
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go6
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go95
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go174
-rw-r--r--pkg/tcpip/network/ip/BUILD25
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol.go546
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol_test.go576
-rw-r--r--pkg/tcpip/network/ip_test.go38
-rw-r--r--pkg/tcpip/network/ipv4/BUILD7
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go6
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go323
-rw-r--r--pkg/tcpip/network/ipv4/igmp_test.go156
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go230
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go277
-rw-r--r--pkg/tcpip/network/ipv6/BUILD15
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go49
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go84
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go157
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go114
-rw-r--r--pkg/tcpip/network/ipv6/mld.go164
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go90
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go332
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go44
-rw-r--r--pkg/tcpip/network/multicast_group_test.go1069
-rw-r--r--pkg/tcpip/network/testutil/testutil.go15
26 files changed, 3884 insertions, 722 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index b38aff0b8..9ebf31b78 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -7,12 +7,14 @@ go_test(
size = "small",
srcs = [
"ip_test.go",
+ "multicast_group_test.go",
],
deps = [
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/header/parse",
"//pkg/tcpip/link/channel",
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index f462524c9..0fb373612 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -319,9 +319,9 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
copy(h.HardwareAddressSender(), test.senderLinkAddr)
copy(h.ProtocolAddressSender(), test.senderAddr)
copy(h.ProtocolAddressTarget(), test.targetAddr)
- c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{
+ c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: v.ToVectorisedView(),
- })
+ }))
if !test.isValid {
// No packets should be sent after receiving an invalid ARP request.
@@ -442,9 +442,9 @@ func (*testInterface) Promiscuous() bool {
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
r := stack.Route{
- NetProto: protocol,
- RemoteLinkAddress: remoteLinkAddr,
+ NetProto: protocol,
}
+ r.ResolveWith(remoteLinkAddr)
return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
}
@@ -557,8 +557,8 @@ func TestLinkAddressRequest(t *testing.T) {
t.Fatal("expected to send a link address request")
}
- if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
+ if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr)
}
rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index c75ca7d71..d31296a41 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -46,9 +46,13 @@ const (
)
var (
- // ErrInvalidArgs indicates to the caller that that an invalid argument was
+ // ErrInvalidArgs indicates to the caller that an invalid argument was
// provided.
ErrInvalidArgs = errors.New("invalid args")
+
+ // ErrFragmentOverlap indicates that, during reassembly, a fragment overlaps
+ // with another one.
+ ErrFragmentOverlap = errors.New("overlapping fragments")
)
// FragmentID is the identifier for a fragment.
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 19f4920b3..04072d966 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -26,9 +26,9 @@ import (
)
type hole struct {
- first uint16
- last uint16
- deleted bool
+ first uint16
+ last uint16
+ filled bool
}
type reassembler struct {
@@ -38,7 +38,7 @@ type reassembler struct {
proto uint8
mu sync.Mutex
holes []hole
- deleted int
+ filled int
heap fragHeap
done bool
creationTime int64
@@ -53,44 +53,86 @@ func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
creationTime: clock.NowMonotonic(),
}
r.holes = append(r.holes, hole{
- first: 0,
- last: math.MaxUint16,
- deleted: false})
+ first: 0,
+ last: math.MaxUint16,
+ filled: false,
+ })
return r
}
-// updateHoles updates the list of holes for an incoming fragment and
-// returns true iff the fragment filled at least part of an existing hole.
-func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
- used := false
+// updateHoles updates the list of holes for an incoming fragment. It returns
+// true if the fragment fits, it is not a duplicate and it does not overlap with
+// another fragment.
+//
+// For IPv6, overlaps with an existing fragment are explicitly forbidden by
+// RFC 8200 section 4.5:
+// If any of the fragments being reassembled overlap with any other fragments
+// being reassembled for the same packet, reassembly of that packet must be
+// abandoned and all the fragments that have been received for that packet
+// must be discarded, and no ICMP error messages should be sent.
+//
+// It is not explicitly forbidden for IPv4, but to keep parity with Linux we
+// disallow it as well:
+// https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349
+func (r *reassembler) updateHoles(first, last uint16, more bool) (bool, error) {
for i := range r.holes {
- if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first {
+ currentHole := &r.holes[i]
+
+ if currentHole.filled || last < currentHole.first || currentHole.last < first {
continue
}
- used = true
- r.deleted++
- r.holes[i].deleted = true
- if first > r.holes[i].first {
- r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false})
+
+ if first < currentHole.first || currentHole.last < last {
+ // Incoming fragment only partially fits in the free hole.
+ return false, ErrFragmentOverlap
+ }
+
+ r.filled++
+ if first > currentHole.first {
+ r.holes = append(r.holes, hole{
+ first: currentHole.first,
+ last: first - 1,
+ filled: false,
+ })
+ }
+ if last < currentHole.last && more {
+ r.holes = append(r.holes, hole{
+ first: last + 1,
+ last: currentHole.last,
+ filled: false,
+ })
}
- if last < r.holes[i].last && more {
- r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false})
+ // Update the current hole to precisely match the incoming fragment.
+ r.holes[i] = hole{
+ first: first,
+ last: last,
+ filled: true,
}
+ return true, nil
}
- return used
+
+ // Incoming fragment is a duplicate/subset, or its offset comes after the end
+ // of the reassembled payload.
+ return false, nil
}
func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) {
r.mu.Lock()
defer r.mu.Unlock()
- consumed := 0
if r.done {
// A concurrent goroutine might have already reassembled
// the packet and emptied the heap while this goroutine
// was waiting on the mutex. We don't have to do anything in this case.
- return buffer.VectorisedView{}, 0, false, consumed, nil
+ return buffer.VectorisedView{}, 0, false, 0, nil
}
- if r.updateHoles(first, last, more) {
+
+ used, err := r.updateHoles(first, last, more)
+ if err != nil {
+ return buffer.VectorisedView{}, 0, false, 0, fmt.Errorf("fragment reassembly failed: %w", err)
+ }
+
+ var consumed int
+ if used {
// For IPv6, it is possible to have different Protocol values between
// fragments of a packet (because, unlike IPv4, the Protocol is not used to
// identify a fragment). In this case, only the Protocol of the first
@@ -109,13 +151,14 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
consumed = vv.Size()
r.size += consumed
}
- // Check if all the holes have been deleted and we are ready to reassamble.
- if r.deleted < len(r.holes) {
+
+ // Check if all the holes have been filled and we are ready to reassemble.
+ if r.filled < len(r.holes) {
return buffer.VectorisedView{}, 0, false, consumed, nil
}
res, err := r.heap.reassemble()
if err != nil {
- return buffer.VectorisedView{}, 0, false, consumed, fmt.Errorf("fragment reassembly failed: %w", err)
+ return buffer.VectorisedView{}, 0, false, 0, fmt.Errorf("fragment reassembly failed: %w", err)
}
return res, r.proto, true, consumed, nil
}
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index a0a04a027..cee3063b1 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -16,92 +16,124 @@ package fragmentation
import (
"math"
- "reflect"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
)
-type updateHolesInput struct {
- first uint16
- last uint16
- more bool
+type updateHolesParams struct {
+ first uint16
+ last uint16
+ more bool
+ wantUsed bool
+ wantError error
}
-var holesTestCases = []struct {
- comment string
- in []updateHolesInput
- want []hole
-}{
- {
- comment: "No fragments. Expected holes: {[0 -> inf]}.",
- in: []updateHolesInput{},
- want: []hole{{first: 0, last: math.MaxUint16, deleted: false}},
- },
- {
- comment: "One fragment at beginning. Expected holes: {[2, inf]}.",
- in: []updateHolesInput{{first: 0, last: 1, more: true}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 2, last: math.MaxUint16, deleted: false},
+func TestUpdateHoles(t *testing.T) {
+ var tests = []struct {
+ name string
+ params []updateHolesParams
+ want []hole
+ }{
+ {
+ name: "No fragments",
+ params: nil,
+ want: []hole{{first: 0, last: math.MaxUint16, filled: false}},
},
- },
- {
- comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.",
- in: []updateHolesInput{{first: 1, last: 2, more: true}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 0, last: 0, deleted: false},
- {first: 3, last: math.MaxUint16, deleted: false},
+ {
+ name: "One fragment at beginning",
+ params: []updateHolesParams{{first: 0, last: 1, more: true, wantUsed: true, wantError: nil}},
+ want: []hole{
+ {first: 0, last: 1, filled: true},
+ {first: 2, last: math.MaxUint16, filled: false},
+ },
},
- },
- {
- comment: "One fragment at the end. Expected holes: {[0, 0]}.",
- in: []updateHolesInput{{first: 1, last: 2, more: false}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 0, last: 0, deleted: false},
+ {
+ name: "One fragment in the middle",
+ params: []updateHolesParams{{first: 1, last: 2, more: true, wantUsed: true, wantError: nil}},
+ want: []hole{
+ {first: 1, last: 2, filled: true},
+ {first: 0, last: 0, filled: false},
+ {first: 3, last: math.MaxUint16, filled: false},
+ },
},
- },
- {
- comment: "One fragment completing a packet. Expected holes: {}.",
- in: []updateHolesInput{{first: 0, last: 1, more: false}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
+ {
+ name: "One fragment at the end",
+ params: []updateHolesParams{{first: 1, last: 2, more: false, wantUsed: true, wantError: nil}},
+ want: []hole{
+ {first: 1, last: 2, filled: true},
+ {first: 0, last: 0, filled: false},
+ },
},
- },
- {
- comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.",
- in: []updateHolesInput{
- {first: 0, last: 1, more: true},
- {first: 2, last: 3, more: false},
+ {
+ name: "One fragment completing a packet",
+ params: []updateHolesParams{{first: 0, last: 1, more: false, wantUsed: true, wantError: nil}},
+ want: []hole{
+ {first: 0, last: 1, filled: true},
+ },
},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 2, last: math.MaxUint16, deleted: true},
+ {
+ name: "Two fragments completing a packet",
+ params: []updateHolesParams{
+ {first: 0, last: 1, more: true, wantUsed: true, wantError: nil},
+ {first: 2, last: 3, more: false, wantUsed: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 1, filled: true},
+ {first: 2, last: 3, filled: true},
+ },
},
- },
- {
- comment: "Two overlapping fragments completing a packet. Expected holes: {}.",
- in: []updateHolesInput{
- {first: 0, last: 2, more: true},
- {first: 2, last: 3, more: false},
+ {
+ name: "Two fragments completing a packet with a duplicate",
+ params: []updateHolesParams{
+ {first: 0, last: 1, more: true, wantUsed: true, wantError: nil},
+ {first: 0, last: 1, more: true, wantUsed: false, wantError: nil},
+ {first: 2, last: 3, more: false, wantUsed: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 1, filled: true},
+ {first: 2, last: 3, filled: true},
+ },
},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 3, last: math.MaxUint16, deleted: true},
+ {
+ name: "Two overlapping fragments",
+ params: []updateHolesParams{
+ {first: 0, last: 10, more: true, wantUsed: true, wantError: nil},
+ {first: 5, last: 15, more: false, wantUsed: false, wantError: ErrFragmentOverlap},
+ {first: 11, last: 15, more: false, wantUsed: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 10, filled: true},
+ {first: 11, last: 15, filled: true},
+ },
},
- },
-}
+ {
+ name: "Out of bounds fragment",
+ params: []updateHolesParams{
+ {first: 0, last: 10, more: true, wantUsed: true, wantError: nil},
+ {first: 11, last: 15, more: false, wantUsed: true, wantError: nil},
+ {first: 16, last: 20, more: false, wantUsed: false, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 10, filled: true},
+ {first: 11, last: 15, filled: true},
+ },
+ },
+ }
-func TestUpdateHoles(t *testing.T) {
- for _, c := range holesTestCases {
- r := newReassembler(FragmentID{}, &faketime.NullClock{})
- for _, i := range c.in {
- r.updateHoles(i.first, i.last, i.more)
- }
- if !reflect.DeepEqual(r.holes, c.want) {
- t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ r := newReassembler(FragmentID{}, &faketime.NullClock{})
+ for _, param := range test.params {
+ used, err := r.updateHoles(param.first, param.last, param.more)
+ if used != param.wantUsed || err != param.wantError {
+ t.Errorf("got r.updateHoles(%d, %d, %t) = (%t, %v), want = (%t, %v)", param.first, param.last, param.more, used, err, param.wantUsed, param.wantError)
+ }
+ }
+ if diff := cmp.Diff(test.want, r.holes, cmp.AllowUnexported(hole{})); diff != "" {
+ t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
+ }
+ })
}
}
diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD
new file mode 100644
index 000000000..6ca200b48
--- /dev/null
+++ b/pkg/tcpip/network/ip/BUILD
@@ -0,0 +1,25 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "ip",
+ srcs = ["generic_multicast_protocol.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sync",
+ "//pkg/tcpip",
+ ],
+)
+
+go_test(
+ name = "ip_test",
+ size = "small",
+ srcs = ["generic_multicast_protocol_test.go"],
+ deps = [
+ ":ip",
+ "//pkg/tcpip",
+ "//pkg/tcpip/faketime",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go
new file mode 100644
index 000000000..f14e2a88a
--- /dev/null
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go
@@ -0,0 +1,546 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ip holds IPv4/IPv6 common utilities.
+package ip
+
+import (
+ "fmt"
+ "math/rand"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// hostState is the state a host may be in for a multicast group.
+type hostState int
+
+// The states below are generic across IGMPv2 (RFC 2236 section 6) and MLDv1
+// (RFC 2710 section 5). Even though the states are generic across both IGMPv2
+// and MLDv1, IGMPv2 terminology will be used.
+const (
+ // nonMember is the "'Non-Member' state, when the host does not belong to the
+ // group on the interface. This is the initial state for all memberships on
+ // all network interfaces; it requires no storage in the host."
+ //
+ // 'Non-Listener' is the MLDv1 term used to describe this state.
+ //
+ // This state is used to keep track of groups that have been joined locally,
+ // but without advertising the membership to the network.
+ nonMember hostState = iota
+
+ // delayingMember is the "'Delaying Member' state, when the host belongs to
+ // the group on the interface and has a report delay timer running for that
+ // membership."
+ //
+ // 'Delaying Listener' is the MLDv1 term used to describe this state.
+ delayingMember
+
+ // idleMember is the "Idle Member" state, when the host belongs to the group
+ // on the interface and does not have a report delay timer running for that
+ // membership.
+ //
+ // 'Idle Listener' is the MLDv1 term used to describe this state.
+ idleMember
+)
+
+// multicastGroupState holds the Generic Multicast Protocol state for a
+// multicast group.
+type multicastGroupState struct {
+ // joins is the number of times the group has been joined.
+ joins uint64
+
+ // state holds the host's state for the group.
+ state hostState
+
+ // lastToSendReport is true if we sent the last report for the group. It is
+ // used to track whether there are other hosts on the subnet that are also
+ // members of the group.
+ //
+ // Defined in RFC 2236 section 6 page 9 for IGMPv2 and RFC 2710 section 5 page
+ // 8 for MLDv1.
+ lastToSendReport bool
+
+ // delayedReportJob is used to delay sending responses to membership report
+ // messages in order to reduce duplicate reports from multiple hosts on the
+ // interface.
+ //
+ // Must not be nil.
+ delayedReportJob *tcpip.Job
+}
+
+// GenericMulticastProtocolOptions holds options for the generic multicast
+// protocol.
+type GenericMulticastProtocolOptions struct {
+ // Enabled indicates whether the generic multicast protocol will be
+ // performed.
+ //
+ // When enabled, the protocol may transmit report and leave messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // packets.
+ //
+ // When disabled, the protocol will still keep track of locally joined groups,
+ // it just won't transmit and handle packets, or update groups' state.
+ Enabled bool
+
+ // Rand is the source of random numbers.
+ Rand *rand.Rand
+
+ // Clock is the clock used to create timers.
+ Clock tcpip.Clock
+
+ // Protocol is the implementation of the variant of multicast group protocol
+ // in use.
+ Protocol MulticastGroupProtocol
+
+ // MaxUnsolicitedReportDelay is the maximum amount of time to wait between
+ // transmitting unsolicited reports.
+ //
+ // Unsolicited reports are transmitted when a group is newly joined.
+ MaxUnsolicitedReportDelay time.Duration
+
+ // AllNodesAddress is a multicast address that all nodes on a network should
+ // be a member of.
+ //
+ // This address will not have the generic multicast protocol performed on it;
+ // it will be left in the non member/listener state, and packets will never
+ // be sent for it.
+ AllNodesAddress tcpip.Address
+}
+
+// MulticastGroupProtocol is a multicast group protocol whose core state machine
+// can be represented by GenericMulticastProtocolState.
+type MulticastGroupProtocol interface {
+ // SendReport sends a multicast report for the specified group address.
+ SendReport(groupAddress tcpip.Address) *tcpip.Error
+
+ // SendLeave sends a multicast leave for the specified group address.
+ SendLeave(groupAddress tcpip.Address) *tcpip.Error
+}
+
+// GenericMulticastProtocolState is the per interface generic multicast protocol
+// state.
+//
+// There is actually no protocol named "Generic Multicast Protocol". Instead,
+// the term used to refer to a generic multicast protocol that applies to both
+// IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state
+// machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710.
+//
+// GenericMulticastProtocolState.Init MUST be called before calling any of
+// the methods on GenericMulticastProtocolState.
+type GenericMulticastProtocolState struct {
+ opts GenericMulticastProtocolOptions
+
+ mu struct {
+ sync.RWMutex
+
+ // memberships holds group addresses and their associated state.
+ memberships map[tcpip.Address]multicastGroupState
+ }
+}
+
+// Init initializes the Generic Multicast Protocol state.
+//
+// maxUnsolicitedReportDelay is the maximum time between sending unsolicited
+// reports after joining a group.
+func (g *GenericMulticastProtocolState) Init(opts GenericMulticastProtocolOptions) {
+ g.mu.Lock()
+ defer g.mu.Unlock()
+ g.opts = opts
+ g.mu.memberships = make(map[tcpip.Address]multicastGroupState)
+}
+
+// MakeAllNonMember transitions all groups to the non-member state.
+//
+// The groups will still be considered joined locally.
+func (g *GenericMulticastProtocolState) MakeAllNonMember() {
+ if !g.opts.Enabled {
+ return
+ }
+
+ g.mu.Lock()
+ defer g.mu.Unlock()
+
+ for groupAddress, info := range g.mu.memberships {
+ g.transitionToNonMemberLocked(groupAddress, &info)
+ g.mu.memberships[groupAddress] = info
+ }
+}
+
+// InitializeGroups initializes each group, as if they were newly joined but
+// without affecting the groups' join count.
+//
+// Must only be called after calling MakeAllNonMember as a group should not be
+// initialized while it is not in the non-member state.
+func (g *GenericMulticastProtocolState) InitializeGroups() {
+ if !g.opts.Enabled {
+ return
+ }
+
+ g.mu.Lock()
+ defer g.mu.Unlock()
+
+ for groupAddress, info := range g.mu.memberships {
+ g.initializeNewMemberLocked(groupAddress, &info)
+ g.mu.memberships[groupAddress] = info
+ }
+}
+
+// JoinGroup handles joining a new group.
+//
+// If dontInitialize is true, the group will be not be initialized and will be
+// left in the non-member state - no packets will be sent for it until it is
+// initialized via InitializeGroups.
+func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, dontInitialize bool) {
+ g.mu.Lock()
+ defer g.mu.Unlock()
+
+ if info, ok := g.mu.memberships[groupAddress]; ok {
+ // The group has already been joined.
+ info.joins++
+ g.mu.memberships[groupAddress] = info
+ return
+ }
+
+ info := multicastGroupState{
+ // Since we just joined the group, its count is 1.
+ joins: 1,
+ // The state will be updated below, if required.
+ state: nonMember,
+ lastToSendReport: false,
+ delayedReportJob: tcpip.NewJob(g.opts.Clock, &g.mu, func() {
+ info, ok := g.mu.memberships[groupAddress]
+ if !ok {
+ panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress))
+ }
+
+ info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil
+ info.state = idleMember
+ g.mu.memberships[groupAddress] = info
+ }),
+ }
+
+ if !dontInitialize && g.opts.Enabled {
+ g.initializeNewMemberLocked(groupAddress, &info)
+ }
+
+ g.mu.memberships[groupAddress] = info
+}
+
+// IsLocallyJoined returns true if the group is locally joined.
+func (g *GenericMulticastProtocolState) IsLocallyJoined(groupAddress tcpip.Address) bool {
+ g.mu.RLock()
+ defer g.mu.RUnlock()
+ _, ok := g.mu.memberships[groupAddress]
+ return ok
+}
+
+// LeaveGroup handles leaving the group.
+//
+// Returns false if the group is not currently joined.
+func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) bool {
+ g.mu.Lock()
+ defer g.mu.Unlock()
+
+ info, ok := g.mu.memberships[groupAddress]
+ if !ok {
+ return false
+ }
+
+ if info.joins == 0 {
+ panic(fmt.Sprintf("tried to leave group %s with a join count of 0", groupAddress))
+ }
+ info.joins--
+ if info.joins != 0 {
+ // If we still have outstanding joins, then do nothing further.
+ g.mu.memberships[groupAddress] = info
+ return true
+ }
+
+ g.transitionToNonMemberLocked(groupAddress, &info)
+ delete(g.mu.memberships, groupAddress)
+ return true
+}
+
+// HandleQuery handles a query message with the specified maximum response time.
+//
+// If the group address is unspecified, then reports will be scheduled for all
+// joined groups.
+//
+// Report(s) will be scheduled to be sent after a random duration between 0 and
+// the maximum response time.
+func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, maxResponseTime time.Duration) {
+ if !g.opts.Enabled {
+ return
+ }
+
+ g.mu.Lock()
+ defer g.mu.Unlock()
+
+ // As per RFC 2236 section 2.4 (for IGMPv2),
+ //
+ // In a Membership Query message, the group address field is set to zero
+ // when sending a General Query, and set to the group address being
+ // queried when sending a Group-Specific Query.
+ //
+ // As per RFC 2710 section 3.6 (for MLDv1),
+ //
+ // In a Query message, the Multicast Address field is set to zero when
+ // sending a General Query, and set to a specific IPv6 multicast address
+ // when sending a Multicast-Address-Specific Query.
+ if groupAddress.Unspecified() {
+ // This is a general query as the group address is unspecified.
+ for groupAddress, info := range g.mu.memberships {
+ g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime)
+ g.mu.memberships[groupAddress] = info
+ }
+ } else if info, ok := g.mu.memberships[groupAddress]; ok {
+ g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime)
+ g.mu.memberships[groupAddress] = info
+ }
+}
+
+// HandleReport handles a report message.
+//
+// If the report is for a joined group, any active delayed report will be
+// cancelled and the host state for the group transitions to idle.
+func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) {
+ if !g.opts.Enabled {
+ return
+ }
+
+ g.mu.Lock()
+ defer g.mu.Unlock()
+
+ // As per RFC 2236 section 3 pages 3-4 (for IGMPv2),
+ //
+ // If the host receives another host's Report (version 1 or 2) while it has
+ // a timer running, it stops its timer for the specified group and does not
+ // send a Report
+ //
+ // As per RFC 2710 section 4 page 6 (for MLDv1),
+ //
+ // If a node receives another node's Report from an interface for a
+ // multicast address while it has a timer running for that same address
+ // on that interface, it stops its timer and does not send a Report for
+ // that address, thus suppressing duplicate reports on the link.
+ if info, ok := g.mu.memberships[groupAddress]; ok && info.state == delayingMember {
+ info.delayedReportJob.Cancel()
+ info.lastToSendReport = false
+ info.state = idleMember
+ g.mu.memberships[groupAddress] = info
+ }
+}
+
+// initializeNewMemberLocked initializes a new group membership.
+//
+// Precondition: g.mu must be locked.
+func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state != nonMember {
+ panic(fmt.Sprintf("state for group %s is not non-member; state = %d", groupAddress, info.state))
+ }
+
+ info.state = idleMember
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ return
+ }
+
+ // As per RFC 2236 section 3 page 5 (for IGMPv2),
+ //
+ // When a host joins a multicast group, it should immediately transmit an
+ // unsolicited Version 2 Membership Report for that group" ... "it is
+ // recommended that it be repeated".
+ //
+ // As per RFC 2710 section 4 page 6 (for MLDv1),
+ //
+ // When a node starts listening to a multicast address on an interface,
+ // it should immediately transmit an unsolicited Report for that address
+ // on that interface, in case it is the first listener on the link. To
+ // cover the possibility of the initial Report being lost or damaged, it
+ // is recommended that it be repeated once or twice after short delays
+ // [Unsolicited Report Interval].
+ //
+ // TODO(gvisor.dev/issue/4901): Support a configurable number of initial
+ // unsolicited reports.
+ info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil
+ g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay)
+}
+
+// maybeSendLeave attempts to send a leave message.
+func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) {
+ if !g.opts.Enabled || !lastToSendReport {
+ return
+ }
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ return
+ }
+
+ // Okay to ignore the error here as if packet write failed, the multicast
+ // routers will eventually drop our membership anyways. If the interface is
+ // being disabled or removed, the generic multicast protocol's should be
+ // cleared eventually.
+ //
+ // As per RFC 2236 section 3 page 5 (for IGMPv2),
+ //
+ // When a router receives a Report, it adds the group being reported to
+ // the list of multicast group memberships on the network on which it
+ // received the Report and sets the timer for the membership to the
+ // [Group Membership Interval]. Repeated Reports refresh the timer. If
+ // no Reports are received for a particular group before this timer has
+ // expired, the router assumes that the group has no local members and
+ // that it need not forward remotely-originated multicasts for that
+ // group onto the attached network.
+ //
+ // As per RFC 2710 section 4 page 5 (for MLDv1),
+ //
+ // When a router receives a Report from a link, if the reported address
+ // is not already present in the router's list of multicast address
+ // having listeners on that link, the reported address is added to the
+ // list, its timer is set to [Multicast Listener Interval], and its
+ // appearance is made known to the router's multicast routing component.
+ // If a Report is received for a multicast address that is already
+ // present in the router's list, the timer for that address is reset to
+ // [Multicast Listener Interval]. If an address's timer expires, it is
+ // assumed that there are no longer any listeners for that address
+ // present on the link, so it is deleted from the list and its
+ // disappearance is made known to the multicast routing component.
+ //
+ // The requirement to send a leave message is also optional (it MAY be
+ // skipped):
+ //
+ // As per RFC 2236 section 6 page 8 (for IGMPv2),
+ //
+ // "send leave" for the group on the interface. If the interface
+ // state says the Querier is running IGMPv1, this action SHOULD be
+ // skipped. If the flag saying we were the last host to report is
+ // cleared, this action MAY be skipped. The Leave Message is sent to
+ // the ALL-ROUTERS group (224.0.0.2).
+ //
+ // As per RFC 2710 section 5 page 8 (for MLDv1),
+ //
+ // "send done" for the address on the interface. If the flag saying
+ // we were the last node to report is cleared, this action MAY be
+ // skipped. The Done message is sent to the link-scope all-routers
+ // address (FF02::2).
+ _ = g.opts.Protocol.SendLeave(groupAddress)
+}
+
+// transitionToNonMemberLocked transitions the given multicast group the the
+// non-member/listener state.
+//
+// Precondition: e.mu must be locked.
+func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state == nonMember {
+ return
+ }
+
+ info.delayedReportJob.Cancel()
+ g.maybeSendLeave(groupAddress, info.lastToSendReport)
+ info.lastToSendReport = false
+ info.state = nonMember
+}
+
+// setDelayTimerForAddressRLocked sets timer to send a delay report.
+//
+// Precondition: g.mu MUST be read locked.
+func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) {
+ if info.state == nonMember {
+ return
+ }
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ return
+ }
+
+ // As per RFC 2236 section 3 page 3 (for IGMPv2),
+ //
+ // If a timer for the group is already unning, it is reset to the random
+ // value only if the requested Max Response Time is less than the remaining
+ // value of the running timer.
+ //
+ // As per RFC 2710 section 4 page 5 (for MLDv1),
+ //
+ // If a timer for any address is already running, it is reset to the new
+ // random value only if the requested Maximum Response Delay is less than
+ // the remaining value of the running timer.
+ if info.state == delayingMember {
+ // TODO: Reset the timer if time remaining is greater than maxResponseTime.
+ return
+ }
+ info.state = delayingMember
+ info.delayedReportJob.Cancel()
+ info.delayedReportJob.Schedule(g.calculateDelayTimerDuration(maxResponseTime))
+}
+
+// calculateDelayTimerDuration returns a random time between (0, maxRespTime].
+func (g *GenericMulticastProtocolState) calculateDelayTimerDuration(maxRespTime time.Duration) time.Duration {
+ // As per RFC 2236 section 3 page 3 (for IGMPv2),
+ //
+ // When a host receives a Group-Specific Query, it sets a delay timer to a
+ // random value selected from the range (0, Max Response Time]...
+ //
+ // As per RFC 2710 section 4 page 6 (for MLDv1),
+ //
+ // When a node receives a Multicast-Address-Specific Query, if it is
+ // listening to the queried Multicast Address on the interface from
+ // which the Query was received, it sets a delay timer for that address
+ // to a random value selected from the range [0, Maximum Response Delay],
+ // as above.
+ if maxRespTime == 0 {
+ return 0
+ }
+ return time.Duration(g.opts.Rand.Int63n(int64(maxRespTime)))
+}
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
new file mode 100644
index 000000000..670be30d4
--- /dev/null
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
@@ -0,0 +1,576 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip_test
+
+import (
+ "math/rand"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+)
+
+const (
+ addr1 = tcpip.Address("\x01")
+ addr2 = tcpip.Address("\x02")
+ addr3 = tcpip.Address("\x03")
+ addr4 = tcpip.Address("\x04")
+
+ maxUnsolicitedReportDelay = time.Second
+)
+
+var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil)
+
+type mockMulticastGroupProtocol struct {
+ sendReportGroupAddrCount map[tcpip.Address]int
+ sendLeaveGroupAddrCount map[tcpip.Address]int
+}
+
+func (m *mockMulticastGroupProtocol) init() {
+ m.sendReportGroupAddrCount = make(map[tcpip.Address]int)
+ m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
+}
+
+func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcpip.Error {
+ m.sendReportGroupAddrCount[groupAddress]++
+ return nil
+}
+
+func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+ m.sendLeaveGroupAddrCount[groupAddress]++
+ return nil
+}
+
+func checkProtocol(mgp *mockMulticastGroupProtocol, sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string {
+ sendReportGroupAddressesMap := make(map[tcpip.Address]int)
+ for _, a := range sendReportGroupAddresses {
+ sendReportGroupAddressesMap[a] = 1
+ }
+
+ sendLeaveGroupAddressesMap := make(map[tcpip.Address]int)
+ for _, a := range sendLeaveGroupAddresses {
+ sendLeaveGroupAddressesMap[a] = 1
+ }
+
+ diff := cmp.Diff(mockMulticastGroupProtocol{
+ sendReportGroupAddrCount: sendReportGroupAddressesMap,
+ sendLeaveGroupAddrCount: sendLeaveGroupAddressesMap,
+ }, *mgp, cmp.AllowUnexported(mockMulticastGroupProtocol{}))
+ mgp.init()
+ return diff
+}
+
+func TestJoinGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ shouldSendReports bool
+ }{
+ {
+ name: "Normal group",
+ addr: addr1,
+ shouldSendReports: true,
+ },
+ {
+ name: "All-nodes group",
+ addr: addr2,
+ shouldSendReports: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ var mgp mockMulticastGroupProtocol
+ mgp.init()
+ clock := faketime.NewManualClock()
+ g.Init(ip.GenericMulticastProtocolOptions{
+ Enabled: true,
+ Rand: rand.New(rand.NewSource(0)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr2,
+ })
+
+ // Joining a group should send a report immediately and another after
+ // a random interval between 0 and the maximum unsolicited report delay.
+ g.JoinGroup(test.addr, false /* dontInitialize */)
+ if test.shouldSendReports {
+ if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestLeaveGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ shouldSendMessages bool
+ }{
+ {
+ name: "Normal group",
+ addr: addr1,
+ shouldSendMessages: true,
+ },
+ {
+ name: "All-nodes group",
+ addr: addr2,
+ shouldSendMessages: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ var mgp mockMulticastGroupProtocol
+ mgp.init()
+ clock := faketime.NewManualClock()
+ g.Init(ip.GenericMulticastProtocolOptions{
+ Enabled: true,
+ Rand: rand.New(rand.NewSource(1)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr2,
+ })
+
+ g.JoinGroup(test.addr, false /* dontInitialize */)
+ if test.shouldSendMessages {
+ if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Leaving a group should send a leave report immediately and cancel any
+ // delayed reports.
+ if !g.LeaveGroup(test.addr) {
+ t.Fatalf("got g.LeaveGroup(%s) = false, want = true", test.addr)
+ }
+ if test.shouldSendMessages {
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestHandleReport(t *testing.T) {
+ tests := []struct {
+ name string
+ reportAddr tcpip.Address
+ expectReportsFor []tcpip.Address
+ }{
+ {
+ name: "Unpecified empty",
+ reportAddr: "",
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Unpecified any",
+ reportAddr: "\x00",
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Specified",
+ reportAddr: addr1,
+ expectReportsFor: []tcpip.Address{addr2},
+ },
+ {
+ name: "Specified all-nodes",
+ reportAddr: addr3,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Specified other",
+ reportAddr: addr4,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ var mgp mockMulticastGroupProtocol
+ mgp.init()
+ clock := faketime.NewManualClock()
+ g.Init(ip.GenericMulticastProtocolOptions{
+ Enabled: true,
+ Rand: rand.New(rand.NewSource(2)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ g.JoinGroup(addr1, false /* dontInitialize */)
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ g.JoinGroup(addr2, false /* dontInitialize */)
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ g.JoinGroup(addr3, false /* dontInitialize */)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a report for a group we have a timer scheduled for should
+ // cancel our delayed report timer for the group.
+ g.HandleReport(test.reportAddr)
+ if len(test.expectReportsFor) != 0 {
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestHandleQuery(t *testing.T) {
+ tests := []struct {
+ name string
+ queryAddr tcpip.Address
+ maxDelay time.Duration
+ expectReportsFor []tcpip.Address
+ }{
+ {
+ name: "Unpecified empty",
+ queryAddr: "",
+ maxDelay: 0,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Unpecified any",
+ queryAddr: "\x00",
+ maxDelay: 1,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Specified",
+ queryAddr: addr1,
+ maxDelay: 2,
+ expectReportsFor: []tcpip.Address{addr1},
+ },
+ {
+ name: "Specified all-nodes",
+ queryAddr: addr3,
+ maxDelay: 3,
+ expectReportsFor: nil,
+ },
+ {
+ name: "Specified other",
+ queryAddr: addr4,
+ maxDelay: 4,
+ expectReportsFor: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ var mgp mockMulticastGroupProtocol
+ mgp.init()
+ clock := faketime.NewManualClock()
+ g.Init(ip.GenericMulticastProtocolOptions{
+ Enabled: true,
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ g.JoinGroup(addr1, false /* dontInitialize */)
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ g.JoinGroup(addr2, false /* dontInitialize */)
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ g.JoinGroup(addr3, false /* dontInitialize */)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a query should make us schedule a new delayed report if it
+ // is a query directed at us or a general query.
+ g.HandleQuery(test.queryAddr, test.maxDelay)
+ if len(test.expectReportsFor) != 0 {
+ clock.Advance(test.maxDelay)
+ if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestJoinCount(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ var mgp mockMulticastGroupProtocol
+ mgp.init()
+ clock := faketime.NewManualClock()
+ g.Init(ip.GenericMulticastProtocolOptions{
+ Enabled: true,
+ Rand: rand.New(rand.NewSource(4)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: time.Second,
+ })
+
+ // Set the join count to 2 for a group.
+ g.JoinGroup(addr1, false /* dontInitialize */)
+ if !g.IsLocallyJoined(addr1) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
+ }
+ // Only the first join should trigger a report to be sent.
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ g.JoinGroup(addr1, false /* dontInitialize */)
+ if !g.IsLocallyJoined(addr1) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
+ }
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Group should still be considered joined after leaving once.
+ if !g.LeaveGroup(addr1) {
+ t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1)
+ }
+ if !g.IsLocallyJoined(addr1) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
+ }
+ // A leave report should only be sent once the join count reaches 0.
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Leaving once more should actually remove us from the group.
+ if !g.LeaveGroup(addr1) {
+ t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1)
+ }
+ if g.IsLocallyJoined(addr1) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1)
+ }
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Group should no longer be joined so we should not have anything to
+ // leave.
+ if g.LeaveGroup(addr1) {
+ t.Fatalf("got g.LeaveGroup(%s) = true, want = false", addr1)
+ }
+ if g.IsLocallyJoined(addr1) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1)
+ }
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+func TestMakeAllNonMemberAndInitialize(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ var mgp mockMulticastGroupProtocol
+ mgp.init()
+ clock := faketime.NewManualClock()
+ g.Init(ip.GenericMulticastProtocolOptions{
+ Enabled: true,
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ g.JoinGroup(addr1, false /* dontInitialize */)
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ g.JoinGroup(addr2, false /* dontInitialize */)
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ g.JoinGroup(addr3, false /* dontInitialize */)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should send the leave reports for each but still consider them locally
+ // joined.
+ g.MakeAllNonMember()
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ clock.Advance(time.Hour)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ for _, group := range []tcpip.Address{addr1, addr2, addr3} {
+ if !g.IsLocallyJoined(group) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", group)
+ }
+ }
+
+ // Should send the initial set of unsolcited reports.
+ g.InitializeGroups()
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+// TestGroupStateNonMember tests that groups do not send packets when in the
+// non-member state, but are still considered locally joined.
+func TestGroupStateNonMember(t *testing.T) {
+ tests := []struct {
+ name string
+ enabled bool
+ dontInitialize bool
+ }{
+ {
+ name: "Disabled",
+ enabled: false,
+ dontInitialize: false,
+ },
+ {
+ name: "Keep non-member",
+ enabled: true,
+ dontInitialize: true,
+ },
+ {
+ name: "disabled and Keep non-member",
+ enabled: false,
+ dontInitialize: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ var mgp mockMulticastGroupProtocol
+ mgp.init()
+ clock := faketime.NewManualClock()
+ g.Init(ip.GenericMulticastProtocolOptions{
+ Enabled: test.enabled,
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ })
+
+ g.JoinGroup(addr1, test.dontInitialize)
+ if !g.IsLocallyJoined(addr1) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
+ }
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ g.JoinGroup(addr2, test.dontInitialize)
+ if !g.IsLocallyJoined(addr2) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr2)
+ }
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ g.HandleQuery(addr1, time.Nanosecond)
+ clock.Advance(time.Nanosecond)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ if !g.LeaveGroup(addr2) {
+ t.Errorf("got g.LeaveGroup(%s) = false, want = true", addr2)
+ }
+ if !g.IsLocallyJoined(addr1) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
+ }
+ if g.IsLocallyJoined(addr2) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr2)
+ }
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ clock.Advance(time.Hour)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index d49c44846..a314dd386 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -193,10 +193,6 @@ func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBu
panic("not implemented")
}
-func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
- return tcpip.ErrNotSupported
-}
-
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*testObject) ARPHardwareType() header.ARPHardwareType {
panic("not implemented")
@@ -207,7 +203,7 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
panic("not implemented")
}
-func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
+func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
@@ -223,7 +219,7 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
}
-func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
+func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
@@ -554,7 +550,7 @@ func TestIPv4Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{
Protocol: 123,
TTL: 123,
TOS: stack.DefaultTOS,
@@ -937,7 +933,7 @@ func TestIPv6Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{
Protocol: 123,
TTL: 123,
TOS: stack.DefaultTOS,
@@ -1093,7 +1089,19 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
dataBuf := [dataLen]byte{1, 2, 3, 4}
data := dataBuf[:]
- ipv4Options := header.IPv4Options{0, 1, 0, 1}
+ ipv4Options := header.IPv4OptionsSerializer{
+ &header.IPv4SerializableListEndOption{},
+ &header.IPv4SerializableNOPOption{},
+ &header.IPv4SerializableListEndOption{},
+ &header.IPv4SerializableNOPOption{},
+ }
+
+ expectOptions := header.IPv4Options{
+ byte(header.IPv4OptionListEndType),
+ byte(header.IPv4OptionNOPType),
+ byte(header.IPv4OptionListEndType),
+ byte(header.IPv4OptionNOPType),
+ }
ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4}
ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:]
@@ -1243,7 +1251,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ipHdrLen := header.IPv4MinimumSize + ipv4Options.SizeWithPadding()
+ ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
totalLen := ipHdrLen + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1266,7 +1274,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
netHdr := pkt.NetworkHeader()
- hdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
if len(netHdr.View()) != hdrLen {
t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
}
@@ -1276,7 +1284,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
checker.DstAddr(remoteIPv4Addr),
checker.IPv4HeaderLength(hdrLen),
checker.IPFullLength(uint16(hdrLen+len(data))),
- checker.IPv4Options(ipv4Options),
+ checker.IPv4Options(expectOptions),
checker.IPPayload(data),
)
},
@@ -1288,7 +1296,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.SizeWithPadding()))
+ ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length()))
ip.Encode(&header.IPv4Fields{
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
@@ -1307,7 +1315,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
netHdr := pkt.NetworkHeader()
- hdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
if len(netHdr.View()) != hdrLen {
t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
}
@@ -1317,7 +1325,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
checker.DstAddr(remoteIPv4Addr),
checker.IPv4HeaderLength(hdrLen),
checker.IPFullLength(uint16(hdrLen+len(data))),
- checker.IPv4Options(ipv4Options),
+ checker.IPv4Options(expectOptions),
checker.IPPayload(data),
)
},
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 6252614ec..32f53f217 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "ipv4",
srcs = [
"icmp.go",
+ "igmp.go",
"ipv4.go",
],
visibility = ["//visibility:public"],
@@ -17,6 +18,7 @@ go_library(
"//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
+ "//pkg/tcpip/network/ip",
"//pkg/tcpip/stack",
],
)
@@ -24,7 +26,10 @@ go_library(
go_test(
name = "ipv4_test",
size = "small",
- srcs = ["ipv4_test.go"],
+ srcs = [
+ "igmp_test.go",
+ "ipv4_test.go",
+ ],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 488945226..8e392f86c 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -63,7 +63,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
stats := e.protocol.stack.Stats()
- received := stats.ICMP.V4PacketsReceived
+ received := stats.ICMP.V4.PacketsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
@@ -130,7 +130,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4Echo:
received.Echo.Increment()
- sent := stats.ICMP.V4PacketsSent
+ sent := stats.ICMP.V4.PacketsSent
if !e.protocol.stack.AllowICMPMessage() {
sent.RateLimited.Increment()
return
@@ -379,7 +379,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
}
defer route.Release()
- sent := p.stack.Stats().ICMP.V4PacketsSent
+ sent := p.stack.Stats().ICMP.V4.PacketsSent
if !p.stack.AllowICMPMessage() {
sent.RateLimited.Increment()
return nil
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
new file mode 100644
index 000000000..0134fadc0
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -0,0 +1,323 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4
+
+import (
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // igmpV1PresentDefault is the initial state for igmpV1Present in the
+ // igmpState. As per RFC 2236 Page 9 says "No IGMPv1 Router Present ... is
+ // the initial state."
+ igmpV1PresentDefault = 0
+
+ // v1RouterPresentTimeout from RFC 2236 Section 8.11, Page 18
+ // See note on igmpState.igmpV1Present for more detail.
+ v1RouterPresentTimeout = 400 * time.Second
+
+ // v1MaxRespTime from RFC 2236 Section 4, Page 5. "The IGMPv1 router
+ // will send General Queries with the Max Response Time set to 0. This MUST
+ // be interpreted as a value of 100 (10 seconds)."
+ //
+ // Note that the Max Response Time field is a value in units of deciseconds.
+ v1MaxRespTime = 10 * time.Second
+
+ // UnsolicitedReportIntervalMax is the maximum delay between sending
+ // unsolicited IGMP reports.
+ //
+ // Obtained from RFC 2236 Section 8.10, Page 19.
+ UnsolicitedReportIntervalMax = 10 * time.Second
+)
+
+// IGMPOptions holds options for IGMP.
+type IGMPOptions struct {
+ // Enabled indicates whether IGMP will be performed.
+ //
+ // When enabled, IGMP may transmit IGMP report and leave messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // IGMP packets.
+ Enabled bool
+}
+
+var _ ip.MulticastGroupProtocol = (*igmpState)(nil)
+
+// igmpState is the per-interface IGMP state.
+//
+// igmpState.init() MUST be called after creating an IGMP state.
+type igmpState struct {
+ // The IPv4 endpoint this igmpState is for.
+ ep *endpoint
+ opts IGMPOptions
+
+ // igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from
+ // RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1
+ // Membership Reports in response to its Queries, and will not pay
+ // attention to Version 2 Membership Reports. Therefore, a state variable
+ // MUST be kept for each interface, describing whether the multicast
+ // Querier on that interface is running IGMPv1 or IGMPv2. This variable
+ // MUST be based upon whether or not an IGMPv1 query was heard in the last
+ // [Version 1 Router Present Timeout] seconds".
+ //
+ // Must be accessed with atomic operations. Holds a value of 1 when true, 0
+ // when false.
+ igmpV1Present uint32
+
+ mu struct {
+ sync.RWMutex
+
+ genericMulticastProtocol ip.GenericMulticastProtocolState
+
+ // igmpV1Job is scheduled when this interface receives an IGMPv1 style
+ // message, upon expiration the igmpV1Present flag is cleared.
+ // igmpV1Job may not be nil once igmpState is initialized.
+ igmpV1Job *tcpip.Job
+ }
+}
+
+// SendReport implements ip.MulticastGroupProtocol.
+func (igmp *igmpState) SendReport(groupAddress tcpip.Address) *tcpip.Error {
+ igmpType := header.IGMPv2MembershipReport
+ if igmp.v1Present() {
+ igmpType = header.IGMPv1MembershipReport
+ }
+ return igmp.writePacket(groupAddress, groupAddress, igmpType)
+}
+
+// SendLeave implements ip.MulticastGroupProtocol.
+func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+ // As per RFC 2236 Section 6, Page 8: "If the interface state says the
+ // Querier is running IGMPv1, this action SHOULD be skipped. If the flag
+ // saying we were the last host to report is cleared, this action MAY be
+ // skipped."
+ if igmp.v1Present() {
+ return nil
+ }
+ return igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup)
+}
+
+// init sets up an igmpState struct, and is required to be called before using
+// a new igmpState.
+func (igmp *igmpState) init(ep *endpoint, opts IGMPOptions) {
+ igmp.mu.Lock()
+ defer igmp.mu.Unlock()
+ igmp.ep = ep
+ igmp.opts = opts
+ igmp.mu.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{
+ Enabled: opts.Enabled,
+ Rand: ep.protocol.stack.Rand(),
+ Clock: ep.protocol.stack.Clock(),
+ Protocol: igmp,
+ MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
+ AllNodesAddress: header.IPv4AllSystems,
+ })
+ igmp.igmpV1Present = igmpV1PresentDefault
+ igmp.mu.igmpV1Job = igmp.ep.protocol.stack.NewJob(&igmp.mu, func() {
+ igmp.setV1Present(false)
+ })
+}
+
+func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) {
+ stats := igmp.ep.protocol.stack.Stats()
+ received := stats.IGMP.PacketsReceived
+ headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
+ h := header.IGMP(headerView)
+
+ // Temporarily reset the checksum field to 0 in order to calculate the proper
+ // checksum.
+ wantChecksum := h.Checksum()
+ h.SetChecksum(0)
+ gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
+ h.SetChecksum(wantChecksum)
+
+ if gotChecksum != wantChecksum {
+ received.ChecksumErrors.Increment()
+ return
+ }
+
+ switch h.Type() {
+ case header.IGMPMembershipQuery:
+ received.MembershipQuery.Increment()
+ if len(headerView) < header.IGMPQueryMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ igmp.handleMembershipQuery(h.GroupAddress(), h.MaxRespTime())
+ case header.IGMPv1MembershipReport:
+ received.V1MembershipReport.Increment()
+ if len(headerView) < header.IGMPReportMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ igmp.handleMembershipReport(h.GroupAddress())
+ case header.IGMPv2MembershipReport:
+ received.V2MembershipReport.Increment()
+ if len(headerView) < header.IGMPReportMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ igmp.handleMembershipReport(h.GroupAddress())
+ case header.IGMPLeaveGroup:
+ received.LeaveGroup.Increment()
+ // As per RFC 2236 Section 6, Page 7: "IGMP messages other than Query or
+ // Report, are ignored in all states"
+
+ default:
+ // As per RFC 2236 Section 2.1 Page 3: "Unrecognized message types should
+ // be silently ignored. New message types may be used by newer versions of
+ // IGMP, by multicast routing protocols, or other uses."
+ received.Unrecognized.Increment()
+ }
+}
+
+func (igmp *igmpState) v1Present() bool {
+ return atomic.LoadUint32(&igmp.igmpV1Present) == 1
+}
+
+func (igmp *igmpState) setV1Present(v bool) {
+ if v {
+ atomic.StoreUint32(&igmp.igmpV1Present, 1)
+ } else {
+ atomic.StoreUint32(&igmp.igmpV1Present, 0)
+ }
+}
+
+func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) {
+ igmp.mu.Lock()
+ defer igmp.mu.Unlock()
+
+ // As per RFC 2236 Section 6, Page 10: If the maximum response time is zero
+ // then change the state to note that an IGMPv1 router is present and
+ // schedule the query received Job.
+ if maxRespTime == 0 && igmp.opts.Enabled {
+ igmp.mu.igmpV1Job.Cancel()
+ igmp.mu.igmpV1Job.Schedule(v1RouterPresentTimeout)
+ igmp.setV1Present(true)
+ maxRespTime = v1MaxRespTime
+ }
+
+ igmp.mu.genericMulticastProtocol.HandleQuery(groupAddress, maxRespTime)
+}
+
+func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) {
+ igmp.mu.Lock()
+ defer igmp.mu.Unlock()
+ igmp.mu.genericMulticastProtocol.HandleReport(groupAddress)
+}
+
+// writePacket assembles and sends an IGMP packet with the provided fields,
+// incrementing the provided stat counter on success.
+func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) *tcpip.Error {
+ igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize))
+ igmpData.SetType(igmpType)
+ igmpData.SetGroupAddress(groupAddress)
+ igmpData.SetChecksum(header.IGMPCalculateChecksum(igmpData))
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(igmp.ep.MaxHeaderLength()),
+ Data: buffer.View(igmpData).ToVectorisedView(),
+ })
+
+ // TODO(gvisor.dev/issue/4888): We should not use the unspecified address,
+ // rather we should select an appropriate local address.
+ localAddr := header.IPv4Any
+ igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{
+ Protocol: header.IGMPProtocolNumber,
+ TTL: header.IGMPTTL,
+ TOS: stack.DefaultTOS,
+ }, header.IPv4OptionsSerializer{
+ &header.IPv4SerializableRouterAlertOption{},
+ })
+
+ sent := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent
+ if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
+ sent.Dropped.Increment()
+ return err
+ }
+ switch igmpType {
+ case header.IGMPv1MembershipReport:
+ sent.V1MembershipReport.Increment()
+ case header.IGMPv2MembershipReport:
+ sent.V2MembershipReport.Increment()
+ case header.IGMPLeaveGroup:
+ sent.LeaveGroup.Increment()
+ default:
+ panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType))
+ }
+ return nil
+}
+
+// joinGroup handles adding a new group to the membership map, setting up the
+// IGMP state for the group, and sending and scheduling the required
+// messages.
+//
+// If the group already exists in the membership map, returns
+// tcpip.ErrDuplicateAddress.
+func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) {
+ igmp.mu.Lock()
+ defer igmp.mu.Unlock()
+ igmp.mu.genericMulticastProtocol.JoinGroup(groupAddress, !igmp.ep.Enabled() /* dontInitialize */)
+}
+
+// isInGroup returns true if the specified group has been joined locally.
+func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool {
+ igmp.mu.Lock()
+ defer igmp.mu.Unlock()
+ return igmp.mu.genericMulticastProtocol.IsLocallyJoined(groupAddress)
+}
+
+// leaveGroup handles removing the group from the membership map, cancels any
+// delay timers associated with that group, and sends the Leave Group message
+// if required.
+func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
+ igmp.mu.Lock()
+ defer igmp.mu.Unlock()
+
+ // LeaveGroup returns false only if the group was not joined.
+ if igmp.mu.genericMulticastProtocol.LeaveGroup(groupAddress) {
+ return nil
+ }
+
+ return tcpip.ErrBadLocalAddress
+}
+
+// softLeaveAll leaves all groups from the perspective of IGMP, but remains
+// joined locally.
+func (igmp *igmpState) softLeaveAll() {
+ igmp.mu.Lock()
+ defer igmp.mu.Unlock()
+ igmp.mu.genericMulticastProtocol.MakeAllNonMember()
+}
+
+// initializeAll attemps to initialize the IGMP state for each group that has
+// been joined locally.
+func (igmp *igmpState) initializeAll() {
+ igmp.mu.Lock()
+ defer igmp.mu.Unlock()
+ igmp.mu.genericMulticastProtocol.InitializeGroups()
+}
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
new file mode 100644
index 000000000..5e139377b
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/igmp_test.go
@@ -0,0 +1,156 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ multicastAddr = tcpip.Address("\xe0\x00\x00\x03")
+ nicID = 1
+)
+
+// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet
+// sent to the provided address with the passed fields set. Raises a t.Error if
+// any field does not match.
+func validateIgmpPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) {
+ t.Helper()
+
+ payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ checker.IPv4(t, payload,
+ checker.DstAddr(remoteAddress),
+ // TTL for an IGMP message must be 1 as per RFC 2236 section 2.
+ checker.TTL(1),
+ checker.IPv4RouterAlert(),
+ checker.IGMP(
+ checker.IGMPType(igmpType),
+ checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)),
+ checker.IGMPGroupAddress(groupAddress),
+ ),
+ )
+}
+
+func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
+ t.Helper()
+
+ // Create an endpoint of queue size 1, since no more than 1 packets are ever
+ // queued in the tests in this file.
+ e := channel.New(1, 1280, linkAddr)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{
+ IGMP: ipv4.IGMPOptions{
+ Enabled: igmpEnabled,
+ },
+ })},
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ return e, s, clock
+}
+
+func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) {
+ buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize)
+
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(len(buf)),
+ TTL: 1,
+ Protocol: uint8(header.IGMPProtocolNumber),
+ SrcAddr: header.IPv4Any,
+ DstAddr: header.IPv4AllSystems,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ igmp := header.IGMP(buf[header.IPv4MinimumSize:])
+ igmp.SetType(igmpType)
+ igmp.SetMaxRespTime(maxRespTime)
+ igmp.SetGroupAddress(groupAddress)
+ igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
+
+ e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+}
+
+// TestIgmpV1Present tests the handling of the case where an IGMPv1 router is
+// present on the network. The IGMP stack will then send IGMPv1 Membership
+// reports for backwards compatibility.
+func TestIgmpV1Present(t *testing.T) {
+ e, s, clock := createStack(t, true)
+
+ if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
+ }
+
+ // This NIC will send an IGMPv2 report immediately, before this test can get
+ // the IGMPv1 General Membership Query in.
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
+ }
+ if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
+ t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
+ }
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Inject an IGMPv1 General Membership Query which is identical to a standard
+ // membership query except the Max Response Time is set to 0, which will tell
+ // the stack that this is a router using IGMPv1. Send it to the all systems
+ // group which is the only group this host belongs to.
+ createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, 0, header.IPv4AllSystems)
+ if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 1 {
+ t.Fatalf("got Membership Queries received = %d, want = 1", got)
+ }
+
+ // Before advancing the clock, verify that this host has not sent a
+ // V1MembershipReport yet.
+ if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 0 {
+ t.Fatalf("got V1MembershipReport messages sent = %d, want = 0", got)
+ }
+
+ // Verify the solicited Membership Report is sent. Now that this NIC has seen
+ // an IGMPv1 query, it should send an IGMPv1 Membership Report.
+ p, ok = e.Read()
+ if ok {
+ t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt)
+ }
+ clock.Advance(ipv4.UnsolicitedReportIntervalMax)
+ p, ok = e.Read()
+ if !ok {
+ t.Fatal("unable to Read IGMP packet, expected V1MembershipReport")
+ }
+ if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 {
+ t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got)
+ }
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr)
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 1efe6297a..3076185cd 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -72,6 +72,7 @@ type endpoint struct {
nic stack.NetworkInterface
dispatcher stack.TransportDispatcher
protocol *protocol
+ igmp igmpState
// enabled is set to 1 when the enpoint is enabled and 0 when it is
// disabled.
@@ -94,6 +95,7 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa
protocol: p,
}
e.mu.addressableEndpointState.Init(e)
+ e.igmp.init(e, p.options.IGMP)
return e
}
@@ -121,11 +123,22 @@ func (e *endpoint) Enable() *tcpip.Error {
// We have no need for the address endpoint.
ep.DecRef()
+ // Groups may have been joined while the endpoint was disabled, or the
+ // endpoint may have left groups from the perspective of IGMP when the
+ // endpoint was disabled. Either way, we need to let routers know to
+ // send us multicast traffic.
+ e.igmp.initializeAll()
+
// As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
// multicast group. Note, the IANA calls the all-hosts multicast group the
// all-systems multicast group.
- _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems)
- return err
+ if err := e.joinGroupLocked(header.IPv4AllSystems); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv4 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllSystems, err))
+ }
+
+ return nil
}
// Enabled implements stack.NetworkEndpoint.
@@ -162,10 +175,14 @@ func (e *endpoint) disableLocked() {
}
// The endpoint may have already left the multicast group.
- if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
+ if err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
}
+ // Leave groups from the perspective of IGMP so that routers know that
+ // we are no longer interested in the group.
+ e.igmp.softLeaveAll()
+
// The address may have already been removed.
if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err))
@@ -198,37 +215,34 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) {
hdrLen := header.IPv4MinimumSize
- var opts header.IPv4Options
- if params.Options != nil {
- var ok bool
- if opts, ok = params.Options.(header.IPv4Options); !ok {
- panic(fmt.Sprintf("want IPv4Options, got %T", params.Options))
- }
- hdrLen += opts.SizeWithPadding()
- if hdrLen > header.IPv4MaximumHeaderSize {
- // Since we have no way to report an error we must either panic or create
- // a packet which is different to what was requested. Choose panic as this
- // would be a programming error that should be caught in testing.
- panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", params.Options.SizeWithPadding(), header.IPv4MaximumOptionsSize))
- }
+ var optLen int
+ if options != nil {
+ optLen = int(options.Length())
+ }
+ hdrLen += optLen
+ if hdrLen > header.IPv4MaximumHeaderSize {
+ // Since we have no way to report an error we must either panic or create
+ // a packet which is different to what was requested. Choose panic as this
+ // would be a programming error that should be caught in testing.
+ panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", optLen, header.IPv4MaximumOptionsSize))
}
ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen))
length := uint16(pkt.Size())
// RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic
// datagrams. Since the DF bit is never being set here, all datagrams
// are non-atomic and need an ID.
- id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1)
ip.Encode(&header.IPv4Fields{
TotalLength: length,
ID: uint16(id),
TTL: params.TTL,
TOS: params.TOS,
Protocol: uint8(params.Protocol),
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
- Options: opts,
+ SrcAddr: srcAddr,
+ DstAddr: dstAddr,
+ Options: options,
})
ip.SetChecksum(^ip.CalculateChecksum())
pkt.NetworkProtocolNumber = ProtocolNumber
@@ -259,7 +273,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- e.addIPHeader(r, pkt, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */)
// iptables filtering. All packets that reach here are locally
// generated.
@@ -347,7 +361,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.addIPHeader(r, pkt, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */)
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
@@ -461,7 +475,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// non-atomic datagrams, so assign an ID to all such datagrams
// according to the definition given in RFC 6864 section 4.
if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 {
- ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
+ ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress, r.RemoteAddress, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
}
}
@@ -566,21 +580,6 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
stats.IP.MalformedPacketsReceived.Increment()
return
}
- srcAddr := h.SourceAddress()
- dstAddr := h.DestinationAddress()
-
- addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
- if addressEndpoint == nil {
- if !e.protocol.Forwarding() {
- stats.IP.InvalidDestinationAddressesReceived.Increment()
- return
- }
-
- _ = e.forwardPacket(pkt)
- return
- }
- subnet := addressEndpoint.AddressWithPrefix().Subnet()
- addressEndpoint.DecRef()
// There has been some confusion regarding verifying checksums. We need
// just look for negative 0 (0xffff) as the checksum, as it's not possible to
@@ -608,16 +607,42 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
return
}
+ srcAddr := h.SourceAddress()
+ dstAddr := h.DestinationAddress()
+
// As per RFC 1122 section 3.2.1.3:
// When a host sends any datagram, the IP source address MUST
// be one of its own IP addresses (but not a broadcast or
// multicast address).
- if directedBroadcast := subnet.IsBroadcast(srcAddr); directedBroadcast || srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) {
+ if srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) {
stats.IP.InvalidSourceAddressesReceived.Increment()
return
}
+ // Make sure the source address is not a subnet-local broadcast address.
+ if addressEndpoint := e.AcquireAssignedAddress(srcAddr, false /* createTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil {
+ subnet := addressEndpoint.Subnet()
+ addressEndpoint.DecRef()
+ if subnet.IsBroadcast(srcAddr) {
+ stats.IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+ }
- pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
+ // The destination address should be an address we own or a group we joined
+ // for us to receive the packet. Otherwise, attempt to forward the packet.
+ if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil {
+ subnet := addressEndpoint.AddressWithPrefix().Subnet()
+ addressEndpoint.DecRef()
+ pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
+ } else if !e.IsInGroup(dstAddr) {
+ if !e.protocol.Forwarding() {
+ stats.IP.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
+
+ _ = e.forwardPacket(pkt)
+ return
+ }
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
@@ -692,6 +717,10 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
e.handleICMP(pkt)
return
}
+ if p == header.IGMPProtocolNumber {
+ e.igmp.handleIGMP(pkt)
+ return
+ }
if opts := h.Options(); len(opts) != 0 {
// TODO(gvisor.dev/issue/4586):
// When we add forwarding support we should use the verified options
@@ -770,28 +799,12 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo
defer e.mu.Unlock()
loopback := e.nic.IsLoopback()
- addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool {
+ return e.mu.addressableEndpointState.AcquireAssignedAddressOrMatching(localAddr, func(addressEndpoint stack.AddressEndpoint) bool {
subnet := addressEndpoint.Subnet()
// IPv4 has a notion of a subnet broadcast address and considers the
// loopback interface bound to an address's whole subnet (on linux).
return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr))
- })
- if addressEndpoint != nil {
- return addressEndpoint
- }
-
- if !allowTemp {
- return nil
- }
-
- addr := localAddr.WithPrefix()
- addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(addr, tempPEB)
- if err != nil {
- // AddAddress only returns an error if the address is already assigned,
- // but we just checked above if the address exists so we expect no error.
- panic(fmt.Sprintf("e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(%s, %d): %s", addr, tempPEB, err))
- }
- return addressEndpoint
+ }, allowTemp, tempPEB)
}
// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
@@ -816,28 +829,43 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
}
// JoinGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.joinGroupLocked(addr)
+}
+
+// joinGroupLocked is like JoinGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
if !header.IsV4MulticastAddress(addr) {
- return false, tcpip.ErrBadAddress
+ return tcpip.ErrBadAddress
}
- e.mu.Lock()
- defer e.mu.Unlock()
- return e.mu.addressableEndpointState.JoinGroup(addr)
+ e.igmp.joinGroup(addr)
+ return nil
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- return e.mu.addressableEndpointState.LeaveGroup(addr)
+ return e.leaveGroupLocked(addr)
+}
+
+// leaveGroupLocked is like LeaveGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+ return e.igmp.leaveGroup(addr)
}
// IsInGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.mu.addressableEndpointState.IsInGroup(addr)
+ return e.igmp.isInGroup(addr)
}
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
@@ -863,6 +891,8 @@ type protocol struct {
hashIV uint32
fragmentation *fragmentation.Fragmentation
+
+ options Options
}
// Number returns the ipv4 protocol number.
@@ -987,17 +1017,23 @@ func addressToUint32(addr tcpip.Address) uint32 {
return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24
}
-// hashRoute calculates a hash value for the given route. It uses the source &
-// destination address, the transport protocol number and a 32-bit number to
-// generate the hash.
-func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
- a := addressToUint32(r.LocalAddress)
- b := addressToUint32(r.RemoteAddress)
+// hashRoute calculates a hash value for the given source/destination pair using
+// the addresses, transport protocol number and a 32-bit number to generate the
+// hash.
+func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
+ a := addressToUint32(srcAddr)
+ b := addressToUint32(dstAddr)
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
-// NewProtocol returns an IPv4 network protocol.
-func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
+// Options holds options to configure a new protocol.
+type Options struct {
+ // IGMP holds options for IGMP.
+ IGMP IGMPOptions
+}
+
+// NewProtocolWithOptions returns an IPv4 network protocol.
+func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
@@ -1007,14 +1043,22 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
}
hashIV := r[buckets]
- p := &protocol{
- stack: s,
- ids: ids,
- hashIV: hashIV,
- defaultTTL: DefaultTTL,
+ return func(s *stack.Stack) stack.NetworkProtocol {
+ p := &protocol{
+ stack: s,
+ ids: ids,
+ hashIV: hashIV,
+ defaultTTL: DefaultTTL,
+ options: opts,
+ }
+ p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
+ return p
}
- p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
- return p
+}
+
+// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options.
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
+ return NewProtocolWithOptions(Options{})(s)
}
func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) {
@@ -1129,6 +1173,12 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres
}
pointer := tsOpt.Pointer()
+ // RFC 791 page 22 states: "The smallest legal value is 5."
+ // Since the pointer is 1 based, and the header is 4 bytes long the
+ // pointer must point beyond the header therefore 4 or less is bad.
+ if pointer <= header.IPv4OptionTimestampHdrLength {
+ return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer
+ }
// To simplify processing below, base further work on the array of timestamps
// beyond the header, rather than on the whole option. Also to aid
// calculations set 'nextSlot' to be 0 based as in the packet it is 1 based.
@@ -1215,7 +1265,15 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength
}
- nextSlot := rrOpt.Pointer() - 1 // Pointer is 1 based.
+ pointer := rrOpt.Pointer()
+ // RFC 791 page 20 states:
+ // The pointer is relative to this option, and the
+ // smallest legal value for the pointer is 4.
+ // Since the pointer is 1 based, and the header is 3 bytes long the
+ // pointer must point beyond the header therefore 3 or less is bad.
+ if pointer <= header.IPv4OptionRecordRouteHdrLength {
+ return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer
+ }
// RFC 791 page 21 says
// If the route data area is already full (the pointer exceeds the
@@ -1230,14 +1288,14 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
// do this (as do most implementations). It is probable that the inclusion
// of these words is a copy/paste error from the timestamp option where
// there are two failure reasons given.
- if nextSlot >= optlen {
+ if pointer > optlen {
return 0, nil
}
// The data area isn't full but there isn't room for a new entry.
// Either Length or Pointer could be bad. We must select Pointer for Linux
- // compatibility, even if only the length is bad.
- if nextSlot+header.IPv4AddressSize > optlen {
+ // compatibility, even if only the length is bad. NB. pointer is 1 based.
+ if pointer+header.IPv4AddressSize > optlen+1 {
if false {
// This is what we would do if we were not being Linux compatible.
// Check for bad pointer or length value. Must be a multiple of 4 after
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 4e4e1f3b4..9e2d2cfd6 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -103,105 +103,6 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
-// TestIPv4Encode checks that ipv4.Encode correctly fills out the requested
-// fields when options are supplied.
-func TestIPv4EncodeOptions(t *testing.T) {
- tests := []struct {
- name string
- options header.IPv4Options
- encodedOptions header.IPv4Options // reply should look like this
- wantIHL int
- }{
- {
- name: "valid no options",
- wantIHL: header.IPv4MinimumSize,
- },
- {
- name: "one byte options",
- options: header.IPv4Options{1},
- encodedOptions: header.IPv4Options{1, 0, 0, 0},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "two byte options",
- options: header.IPv4Options{1, 1},
- encodedOptions: header.IPv4Options{1, 1, 0, 0},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "three byte options",
- options: header.IPv4Options{1, 1, 1},
- encodedOptions: header.IPv4Options{1, 1, 1, 0},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "four byte options",
- options: header.IPv4Options{1, 1, 1, 1},
- encodedOptions: header.IPv4Options{1, 1, 1, 1},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "five byte options",
- options: header.IPv4Options{1, 1, 1, 1, 1},
- encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0},
- wantIHL: header.IPv4MinimumSize + 8,
- },
- {
- name: "thirty nine byte options",
- options: header.IPv4Options{
- 1, 2, 3, 4, 5, 6, 7, 8,
- 9, 10, 11, 12, 13, 14, 15, 16,
- 17, 18, 19, 20, 21, 22, 23, 24,
- 25, 26, 27, 28, 29, 30, 31, 32,
- 33, 34, 35, 36, 37, 38, 39,
- },
- encodedOptions: header.IPv4Options{
- 1, 2, 3, 4, 5, 6, 7, 8,
- 9, 10, 11, 12, 13, 14, 15, 16,
- 17, 18, 19, 20, 21, 22, 23, 24,
- 25, 26, 27, 28, 29, 30, 31, 32,
- 33, 34, 35, 36, 37, 38, 39, 0,
- },
- wantIHL: header.IPv4MinimumSize + 40,
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- paddedOptionLength := test.options.SizeWithPadding()
- ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength
- if ipHeaderLength > header.IPv4MaximumHeaderSize {
- t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
- }
- totalLen := uint16(ipHeaderLength)
- hdr := buffer.NewPrependable(int(totalLen))
- ip := header.IPv4(hdr.Prepend(ipHeaderLength))
- // To check the padding works, poison the last byte of the options space.
- if paddedOptionLength != len(test.options) {
- ip.SetHeaderLength(uint8(ipHeaderLength))
- ip.Options()[paddedOptionLength-1] = 0xff
- ip.SetHeaderLength(0)
- }
- ip.Encode(&header.IPv4Fields{
- Options: test.options,
- })
- options := ip.Options()
- wantOptions := test.encodedOptions
- if got, want := int(ip.HeaderLength()), test.wantIHL; got != want {
- t.Errorf("got IHL of %d, want %d", got, want)
- }
-
- // cmp.Diff does not consider nil slices equal to empty slices, but we do.
- if len(wantOptions) == 0 && len(options) == 0 {
- return
- }
-
- if diff := cmp.Diff(wantOptions, options); diff != "" {
- t.Errorf("options mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
func TestForwarding(t *testing.T) {
const (
nicID1 = 1
@@ -453,14 +354,6 @@ func TestIPv4Sanity(t *testing.T) {
replyOptions: header.IPv4Options{1, 1, 0, 0},
},
{
- name: "Check option padding",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{1, 1, 1},
- replyOptions: header.IPv4Options{1, 1, 1, 0},
- },
- {
name: "bad header length",
headerLength: header.IPv4MinimumSize - 1,
maxTotalLength: ipv4.MaxTotalSize,
@@ -583,7 +476,7 @@ func TestIPv4Sanity(t *testing.T) {
68, 7, 5, 0,
// ^ ^ Linux points here which is wrong.
// | Not a multiple of 4
- 1, 2, 3,
+ 1, 2, 3, 0,
},
shouldFail: true,
expectErrorICMP: true,
@@ -662,6 +555,56 @@ func TestIPv4Sanity(t *testing.T) {
},
},
{
+ // Timestamp pointer uses one based counting so 0 is invalid.
+ name: "timestamp pointer invalid",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 68, 8, 0, 0x00,
+ // ^ 0 instead of 5 or more.
+ 0, 0, 0, 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ // Timestamp pointer cannot be less than 5. It must point past the header
+ // which is 4 bytes. (1 based counting)
+ name: "timestamp pointer too small by 1",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 68, 8, header.IPv4OptionTimestampHdrLength, 0x00,
+ // ^ header is 4 bytes, so 4 should fail.
+ 0, 0, 0, 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ name: "valid timestamp pointer",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 68, 8, header.IPv4OptionTimestampHdrLength + 1, 0x00,
+ // ^ header is 4 bytes, so 5 should succeed.
+ 0, 0, 0, 0,
+ },
+ replyOptions: header.IPv4Options{
+ 68, 8, 9, 0x00,
+ 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
+ },
+ },
+ {
// Needs 8 bytes for a type 1 timestamp but there are only 4 free.
name: "bad timer element alignment",
maxTotalLength: ipv4.MaxTotalSize,
@@ -792,7 +735,61 @@ func TestIPv4Sanity(t *testing.T) {
},
},
{
- // Confirm linux bug for bug compatibility.
+ // Pointer uses one based counting so 0 is invalid.
+ name: "record route pointer zero",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 7, 8, 0, // 3 byte header
+ 0, 0, 0, 0,
+ 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ // Pointer must be 4 or more as it must point past the 3 byte header
+ // using 1 based counting. 3 should fail.
+ name: "record route pointer too small by 1",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 7, 8, header.IPv4OptionRecordRouteHdrLength, // 3 byte header
+ 0, 0, 0, 0,
+ 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ // Pointer must be 4 or more as it must point past the 3 byte header
+ // using 1 based counting. Check 4 passes. (Duplicates "single
+ // record route with room")
+ name: "valid record route pointer",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 7, 7, header.IPv4OptionRecordRouteHdrLength + 1, // 3 byte header
+ 0, 0, 0, 0,
+ 0,
+ },
+ replyOptions: header.IPv4Options{
+ 7, 7, 8, // 3 byte header
+ 192, 168, 1, 58, // New IP Address.
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ // Confirm Linux bug for bug compatibility.
// Linux returns slot 22 but the error is in slot 21.
name: "multiple record route with not enough room",
maxTotalLength: ipv4.MaxTotalSize,
@@ -863,8 +860,10 @@ func TestIPv4Sanity(t *testing.T) {
},
})
- paddedOptionLength := test.options.SizeWithPadding()
- ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength
+ if len(test.options)%4 != 0 {
+ t.Fatalf("options must be aligned to 32 bits, invalid test options: %x (len=%d)", test.options, len(test.options))
+ }
+ ipHeaderLength := header.IPv4MinimumSize + len(test.options)
if ipHeaderLength > header.IPv4MaximumHeaderSize {
t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
}
@@ -883,11 +882,6 @@ func TestIPv4Sanity(t *testing.T) {
if test.maxTotalLength < totalLen {
totalLen = test.maxTotalLength
}
- // To check the padding works, poison the options space.
- if paddedOptionLength != len(test.options) {
- ip.SetHeaderLength(uint8(ipHeaderLength))
- ip.Options()[paddedOptionLength-1] = 0x01
- }
ip.Encode(&header.IPv4Fields{
TotalLength: totalLen,
@@ -895,10 +889,19 @@ func TestIPv4Sanity(t *testing.T) {
TTL: test.TTL,
SrcAddr: remoteIPv4Addr,
DstAddr: ipv4Addr.Address,
- Options: test.options,
})
if test.headerLength != 0 {
ip.SetHeaderLength(test.headerLength)
+ } else {
+ // Set the calculated header length, since we may manually add options.
+ ip.SetHeaderLength(uint8(ipHeaderLength))
+ }
+ if len(test.options) != 0 {
+ // Copy options manually. We do not use Encode for options so we can
+ // verify malformed options with handcrafted payloads.
+ if want, got := copy(ip.Options(), test.options), len(test.options); want != got {
+ t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want)
+ }
}
ip.SetChecksum(0)
ipHeaderChecksum := ip.CalculateChecksum()
@@ -1003,7 +1006,7 @@ func TestIPv4Sanity(t *testing.T) {
}
// If the IP options change size then the packet will change size, so
// some IP header fields will need to be adjusted for the checks.
- sizeChange := len(test.replyOptions) - paddedOptionLength
+ sizeChange := len(test.replyOptions) - len(test.options)
checker.IPv4(t, replyIPHeader,
checker.IPv4HeaderLength(ipHeaderLength+sizeChange),
@@ -2320,6 +2323,28 @@ func TestReceiveFragments(t *testing.T) {
},
expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
},
+ {
+ name: "Two fragments with MF flag reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload4Addr1ToAddr2[:65512],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 65512,
+ payload: ipv4Payload4Addr1ToAddr2[65512:],
+ },
+ },
+ expectedPayloads: nil,
+ },
}
for _, test := range tests {
@@ -2513,7 +2538,7 @@ func TestWriteStats(t *testing.T) {
test.setup(t, rt.Stack())
- nWritten, _ := writer.writePackets(&rt, pkts)
+ nWritten, _ := writer.writePackets(rt, pkts)
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
@@ -2530,7 +2555,7 @@ func TestWriteStats(t *testing.T) {
}
}
-func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
})
@@ -2644,8 +2669,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != header.IPv4ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
}
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr)
}
checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
@@ -2687,8 +2712,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != header.IPv4ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
}
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr)
}
checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
@@ -2736,8 +2761,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != arp.ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber)
}
- if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
+ if got := p.Route.RemoteLinkAddress(); got != header.EthernetBroadcastAddress {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, header.EthernetBroadcastAddress)
}
rep := header.ARP(p.Pkt.NetworkHeader().View())
if got := rep.Op(); got != header.ARPRequest {
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index 0ac24a6fb..5e75c8740 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -8,6 +8,7 @@ go_library(
"dhcpv6configurationfromndpra_string.go",
"icmp.go",
"ipv6.go",
+ "mld.go",
"ndp.go",
],
visibility = ["//visibility:public"],
@@ -19,6 +20,7 @@ go_library(
"//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
+ "//pkg/tcpip/network/ip",
"//pkg/tcpip/stack",
],
)
@@ -49,3 +51,16 @@ go_test(
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
+
+go_test(
+ name = "ipv6_x_test",
+ size = "small",
+ srcs = ["mld_test.go"],
+ deps = [
+ ":ipv6",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index beb8f562e..510276b8e 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -126,8 +126,8 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) {
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
stats := e.protocol.stack.Stats().ICMP
- sent := stats.V6PacketsSent
- received := stats.V6PacketsReceived
+ sent := stats.V6.PacketsSent
+ received := stats.V6.PacketsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
@@ -163,7 +163,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
}
// TODO(b/112892170): Meaningfully handle all ICMP types.
- switch h.Type() {
+ switch icmpType := h.Type(); icmpType {
case header.ICMPv6PacketTooBig:
received.PacketTooBig.Increment()
hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize)
@@ -358,7 +358,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize))
packet.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(packet.NDPPayload())
+ na := header.NDPNeighborAdvert(packet.MessageBody())
// As per RFC 4861 section 7.2.4:
//
@@ -644,8 +644,31 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
return
}
+ case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone:
+ var handler func(header.MLD)
+ switch icmpType {
+ case header.ICMPv6MulticastListenerQuery:
+ received.MulticastListenerQuery.Increment()
+ handler = e.mld.handleMulticastListenerQuery
+ case header.ICMPv6MulticastListenerReport:
+ received.MulticastListenerReport.Increment()
+ handler = e.mld.handleMulticastListenerReport
+ case header.ICMPv6MulticastListenerDone:
+ received.MulticastListenerDone.Increment()
+ default:
+ panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType))
+ }
+ if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+
+ if handler != nil {
+ handler(header.MLD(payload.ToView()))
+ }
+
default:
- received.Invalid.Increment()
+ received.Unrecognized.Increment()
}
}
@@ -681,12 +704,12 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
packet.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(packet.NDPPayload())
+ ns := header.NDPNeighborSolicit(packet.MessageBody())
ns.SetTargetAddress(targetAddr)
ns.Options().Serialize(optsSerializer)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- stat := p.stack.Stats().ICMP.V6PacketsSent
+ stat := p.stack.Stats().ICMP.V6.PacketsSent
if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
@@ -796,7 +819,8 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
allowResponseToMulticast = reason.respondToMulticast
}
- if (!allowResponseToMulticast && header.IsV6MulticastAddress(origIPHdrDst)) || origIPHdrSrc == header.IPv6Any {
+ isOrigDstMulticast := header.IsV6MulticastAddress(origIPHdrDst)
+ if (!allowResponseToMulticast && isOrigDstMulticast) || origIPHdrSrc == header.IPv6Any {
return nil
}
@@ -812,8 +836,13 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
// If we are operating as a router, do not use the packet's destination
// address as the response's source address as we should not own the
// destination address of a packet we are forwarding.
+ //
+ // If the packet was originally destined to a multicast address, then do not
+ // use the packet's destination address as the source for the response ICMP
+ // packet as "multicast addresses must not be used as source addresses in IPv6
+ // packets", as per RFC 4291 section 2.7.
localAddr := origIPHdrDst
- if _, ok := reason.(*icmpReasonHopLimitExceeded); ok {
+ if _, ok := reason.(*icmpReasonHopLimitExceeded); ok || isOrigDstMulticast {
localAddr = ""
}
// Even if we were able to receive a packet from some remote, we may not have
@@ -827,7 +856,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
defer route.Release()
stats := p.stack.Stats().ICMP
- sent := stats.V6PacketsSent
+ sent := stats.V6.PacketsSent
if !p.stack.AllowICMPMessage() {
sent.RateLimited.Increment()
return nil
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 9bc02d851..32adb5c83 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -150,9 +150,9 @@ func (*testInterface) Promiscuous() bool {
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
r := stack.Route{
- NetProto: protocol,
- RemoteLinkAddress: remoteLinkAddr,
+ NetProto: protocol,
}
+ r.ResolveWith(remoteLinkAddr)
return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
}
@@ -271,6 +271,22 @@ func TestICMPCounts(t *testing.T) {
typ: header.ICMPv6RedirectMsg,
size: header.ICMPv6MinimumSize,
},
+ {
+ typ: header.ICMPv6MulticastListenerQuery,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerReport,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerDone,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: 255, /* Unrecognized */
+ size: 50,
+ },
}
handleIPv6Payload := func(icmp header.ICMPv6) {
@@ -301,7 +317,7 @@ func TestICMPCounts(t *testing.T) {
// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
- icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
+ icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived
visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
if got, want := s.Value(), uint64(1); got != want {
t.Errorf("got %s = %d, want = %d", name, got, want)
@@ -413,6 +429,22 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
typ: header.ICMPv6RedirectMsg,
size: header.ICMPv6MinimumSize,
},
+ {
+ typ: header.ICMPv6MulticastListenerQuery,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerReport,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerDone,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: 255, /* Unrecognized */
+ size: 50,
+ },
}
handleIPv6Payload := func(icmp header.ICMPv6) {
@@ -443,7 +475,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
- icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
+ icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived
visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
if got, want := s.Value(), uint64(1); got != want {
t.Errorf("got %s = %d, want = %d", name, got, want)
@@ -568,8 +600,8 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.
return
}
- if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress {
- t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr)
+ if got := pi.Route.RemoteLinkAddress(); len(args.remoteLinkAddr) != 0 && got != args.remoteLinkAddr {
+ t.Errorf("got remote link address = %s, want = %s", got, args.remoteLinkAddr)
}
// Pull the full payload since network header. Needed for header.IPv6 to
@@ -833,7 +865,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
e.InjectInbound(ProtocolNumber, pkt)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
routerOnly := stats.RouterOnlyPacketsDroppedByHost
typStat := typ.statCounter(stats)
@@ -1028,7 +1060,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
e.InjectInbound(ProtocolNumber, pkt)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
typStat := typ.statCounter(stats)
@@ -1207,7 +1239,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
e.InjectInbound(ProtocolNumber, pkt)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
typStat := typ.statCounter(stats)
@@ -1349,8 +1381,8 @@ func TestLinkAddressRequest(t *testing.T) {
if !ok {
t.Fatal("expected to send a link address request")
}
- if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
+ if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr)
}
if pkt.Route.RemoteAddress != test.expectedRemoteAddr {
t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr)
@@ -1431,8 +1463,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
}
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
@@ -1473,8 +1505,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
}
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
@@ -1524,8 +1556,8 @@ func TestPacketQueing(t *testing.T) {
t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber)
}
snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address)
- if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
@@ -1543,7 +1575,7 @@ func TestPacketQueing(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
pkt := header.ICMPv6(hdr.Prepend(naSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address)
@@ -1592,7 +1624,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
return icmp
},
@@ -1612,7 +1644,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
ns.Options().Serialize(header.NDPOptionsSerializer{
header.NDPSourceLinkLayerAddressOption(linkAddr1),
@@ -1629,7 +1661,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
return icmp
},
@@ -1645,7 +1677,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
ns.Options().Serialize(header.NDPOptionsSerializer{
header.NDPSourceLinkLayerAddressOption(linkAddr1),
@@ -1662,7 +1694,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1683,7 +1715,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1702,7 +1734,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(false)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1722,7 +1754,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(false)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 7a00f6314..8bf84601f 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -34,7 +34,9 @@ import (
)
const (
+ // ReassembleTimeout controls how long a fragment will be held.
// As per RFC 8200 section 4.5:
+ //
// If insufficient fragments are received to complete reassembly of a packet
// within 60 seconds of the reception of the first-arriving fragment of that
// packet, reassembly of that packet must be abandoned.
@@ -84,6 +86,8 @@ type endpoint struct {
addressableEndpointState stack.AddressableEndpointState
ndp ndpState
}
+
+ mld mldState
}
// NICNameFromID is a function that returns a stable name for the specified NIC,
@@ -224,6 +228,12 @@ func (e *endpoint) Enable() *tcpip.Error {
return nil
}
+ // Groups may have been joined when the endpoint was disabled, or the
+ // endpoint may have left groups from the perspective of MLD when the
+ // endpoint was disabled. Either way, we need to let routers know to
+ // send us multicast traffic.
+ e.mld.initializeAll()
+
// Join the IPv6 All-Nodes Multicast group if the stack is configured to
// use IPv6. This is required to ensure that this node properly receives
// and responds to the various NDP messages that are destined to the
@@ -241,8 +251,10 @@ func (e *endpoint) Enable() *tcpip.Error {
// (NDP NS) messages may be sent to the All-Nodes multicast group if the
// source address of the NDP NS is the unspecified address, as per RFC 4861
// section 7.2.4.
- if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil {
- return err
+ if err := e.joinGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv6 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv6AllNodesMulticastAddress, err))
}
// Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
@@ -251,7 +263,7 @@ func (e *endpoint) Enable() *tcpip.Error {
// Addresses may have aleady completed DAD but in the time since the endpoint
// was last enabled, other devices may have acquired the same addresses.
var err *tcpip.Error
- e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
addr := addressEndpoint.AddressWithPrefix().Address
if !header.IsV6UnicastAddress(addr) {
return true
@@ -273,7 +285,7 @@ func (e *endpoint) Enable() *tcpip.Error {
}
// Do not auto-generate an IPv6 link-local address for loopback devices.
- if e.protocol.autoGenIPv6LinkLocal && !e.nic.IsLoopback() {
+ if e.protocol.options.AutoGenLinkLocal && !e.nic.IsLoopback() {
// The valid and preferred lifetime is infinite for the auto-generated
// link-local address.
e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime)
@@ -331,9 +343,13 @@ func (e *endpoint) disableLocked() {
e.stopDADForPermanentAddressesLocked()
// The endpoint may have already left the multicast group.
- if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ if err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
}
+
+ // Leave groups from the perspective of MLD so that routers know that
+ // we are no longer interested in the group.
+ e.mld.softLeaveAll()
}
// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses.
@@ -341,7 +357,7 @@ func (e *endpoint) disableLocked() {
// Precondition: e.mu must be write locked.
func (e *endpoint) stopDADForPermanentAddressesLocked() {
// Stop DAD for all the tentative unicast addresses.
- e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
if addressEndpoint.GetKind() != stack.PermanentTentative {
return true
}
@@ -376,7 +392,7 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
-func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
length := uint16(pkt.Size())
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
@@ -384,8 +400,8 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
NextHeader: uint8(params.Protocol),
HopLimit: params.TTL,
TrafficClass: params.TOS,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ SrcAddr: srcAddr,
+ DstAddr: dstAddr,
})
pkt.NetworkProtocolNumber = ProtocolNumber
}
@@ -440,7 +456,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- e.addIPHeader(r, pkt, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params)
// iptables filtering. All packets that reach here are locally
// generated.
@@ -529,7 +545,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
linkMTU := e.nic.MTU()
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
- e.addIPHeader(r, pb, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params)
networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size()))
if err != nil {
@@ -737,8 +753,11 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
return
}
- addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
- if addressEndpoint == nil {
+ // The destination address should be an address we own or a group we joined
+ // for us to receive the packet. Otherwise, attempt to forward the packet.
+ if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil {
+ addressEndpoint.DecRef()
+ } else if !e.IsInGroup(dstAddr) {
if !e.protocol.Forwarding() {
stats.IP.InvalidDestinationAddressesReceived.Increment()
return
@@ -747,7 +766,6 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
_ = e.forwardPacket(pkt)
return
}
- addressEndpoint.DecRef()
// vv consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
@@ -1090,9 +1108,16 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
//
// Which when taken together indicate that an unknown protocol should
// be treated as an unrecognized next header value.
+ // The location of the Next Header field is in a different place in
+ // the initial IPv6 header than it is in the extension headers so
+ // treat it specially.
+ prevHdrIDOffset := uint32(header.IPv6NextHeaderOffset)
+ if previousHeaderStart != 0 {
+ prevHdrIDOffset = previousHeaderStart
+ }
_ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownHeader,
- pointer: it.ParseOffset(),
+ pointer: prevHdrIDOffset,
}, pkt)
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
@@ -1100,12 +1125,11 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
}
default:
- _ = e.protocol.returnError(&icmpReasonParameterProblem{
- code: header.ICMPv6UnknownHeader,
- pointer: it.ParseOffset(),
- }, pkt)
- stats.UnknownProtocolRcvdPackets.Increment()
- return
+ // Since the iterator returns IPv6RawPayloadHeader for unknown Extension
+ // Header IDs this should never happen unless we missed a supported type
+ // here.
+ panic(fmt.Sprintf("unrecognized type from it.Next() = %T", extHdr))
+
}
}
}
@@ -1154,8 +1178,10 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
}
snmc := header.SolicitedNodeAddr(addr.Address)
- if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil {
- return nil, err
+ if err := e.joinGroupLocked(snmc); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv6 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err))
}
addressEndpoint.SetKind(stack.PermanentTentative)
@@ -1211,7 +1237,8 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn
}
snmc := header.SolicitedNodeAddr(addr.Address)
- if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
+ // The endpoint may have already left the multicast group.
+ if err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
return err
}
@@ -1234,7 +1261,7 @@ func (e *endpoint) hasPermanentAddressRLocked(addr tcpip.Address) bool {
//
// Precondition: e.mu must be read or write locked.
func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpoint {
- return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr)
+ return e.mu.addressableEndpointState.GetAddress(localAddr)
}
// MainAddress implements stack.AddressableEndpoint.
@@ -1285,7 +1312,7 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
// Create a candidate set of available addresses we can potentially use as a
// source address.
var cs []addrCandidate
- e.mu.addressableEndpointState.ReadOnly().ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) {
+ e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) {
// If r is not valid for outgoing connections, it is not a valid endpoint.
if !addressEndpoint.IsAssigned(allowExpired) {
return
@@ -1376,28 +1403,43 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
}
// JoinGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.joinGroupLocked(addr)
+}
+
+// joinGroupLocked is like JoinGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
if !header.IsV6MulticastAddress(addr) {
- return false, tcpip.ErrBadAddress
+ return tcpip.ErrBadAddress
}
- e.mu.Lock()
- defer e.mu.Unlock()
- return e.mu.addressableEndpointState.JoinGroup(addr)
+ e.mld.joinGroup(addr)
+ return nil
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- return e.mu.addressableEndpointState.LeaveGroup(addr)
+ return e.leaveGroupLocked(addr)
+}
+
+// leaveGroupLocked is like LeaveGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+ return e.mld.leaveGroup(addr)
}
// IsInGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.mu.addressableEndpointState.IsInGroup(addr)
+ return e.mld.isInGroup(addr)
}
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
@@ -1405,7 +1447,8 @@ var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
type protocol struct {
- stack *stack.Stack
+ stack *stack.Stack
+ options Options
mu struct {
sync.RWMutex
@@ -1429,26 +1472,6 @@ type protocol struct {
forwarding uint32
fragmentation *fragmentation.Fragmentation
-
- // ndpDisp is the NDP event dispatcher that is used to send the netstack
- // integrator NDP related events.
- ndpDisp NDPDispatcher
-
- // ndpConfigs is the default NDP configurations used by an IPv6 endpoint.
- ndpConfigs NDPConfigurations
-
- // opaqueIIDOpts hold the options for generating opaque interface identifiers
- // (IIDs) as outlined by RFC 7217.
- opaqueIIDOpts OpaqueInterfaceIdentifierOptions
-
- // tempIIDSeed is used to seed the initial temporary interface identifier
- // history value used to generate IIDs for temporary SLAAC addresses.
- tempIIDSeed []byte
-
- // autoGenIPv6LinkLocal determines whether or not the stack attempts to
- // auto-generate an IPv6 link-local address for newly enabled non-loopback
- // NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
- autoGenIPv6LinkLocal bool
}
// Number returns the ipv6 protocol number.
@@ -1484,13 +1507,14 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
e.mu.addressableEndpointState.Init(e)
e.mu.ndp = ndpState{
ep: e,
- configs: p.ndpConfigs,
+ configs: p.options.NDPConfigs,
dad: make(map[tcpip.Address]dadState),
defaultRouters: make(map[tcpip.Address]defaultRouterState),
onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
}
e.mu.ndp.initializeTempAddrState()
+ e.mld.init(e, p.options.MLD)
p.mu.Lock()
defer p.mu.Unlock()
@@ -1613,17 +1637,17 @@ type Options struct {
// NDPConfigs is the default NDP configurations used by interfaces.
NDPConfigs NDPConfigurations
- // AutoGenIPv6LinkLocal determines whether or not the stack attempts to
- // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // AutoGenLinkLocal determines whether or not the stack attempts to
+ // auto-generate a link-local address for newly enabled non-loopback
// NICs.
//
// Note, setting this to true does not mean that a link-local address is
// assigned right away, or at all. If Duplicate Address Detection is enabled,
// an address is only assigned if it successfully resolves. If it fails, no
- // further attempts are made to auto-generate an IPv6 link-local adddress.
+ // further attempts are made to auto-generate a link-local adddress.
//
// The generated link-local address follows RFC 4291 Appendix A guidelines.
- AutoGenIPv6LinkLocal bool
+ AutoGenLinkLocal bool
// NDPDisp is the NDP event dispatcher that an integrator can provide to
// receive NDP related events.
@@ -1647,6 +1671,9 @@ type Options struct {
// seed that is too small would reduce randomness and increase predictability,
// defeating the purpose of temporary SLAAC addresses.
TempIIDSeed []byte
+
+ // MLD holds options for MLD.
+ MLD MLDOptions
}
// NewProtocolWithOptions returns an IPv6 network protocol.
@@ -1658,15 +1685,11 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
return func(s *stack.Stack) stack.NetworkProtocol {
p := &protocol{
- stack: s,
+ stack: s,
+ options: opts,
+
ids: ids,
hashIV: hashIV,
-
- ndpDisp: opts.NDPDisp,
- ndpConfigs: opts.NDPConfigs,
- opaqueIIDOpts: opts.OpaqueIIDOpts,
- tempIIDSeed: opts.TempIIDSeed,
- autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
}
p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
p.mu.eps = make(map[*endpoint]struct{})
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index a671d4bac..1c01f17ab 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -51,6 +51,7 @@ const (
fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier)
destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier)
noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier)
+ unknownHdrID = uint8(header.IPv6UnknownExtHdrIdentifier)
extraHeaderReserve = 50
)
@@ -79,7 +80,7 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
Data: hdr.View().ToVectorisedView(),
}))
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
if got := stats.NeighborAdvert.Value(); got != want {
t.Fatalf("got NeighborAdvert = %d, want = %d", got, want)
@@ -573,6 +574,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
expectICMP: false,
},
{
+ name: "unknown next header (first)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0, 63, 4, 1, 2, 3, 4,
+ }, unknownHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6NextHeaderOffset,
+ },
+ {
+ name: "unknown next header (not first)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ unknownHdrID, 0,
+ 63, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6FixedHeaderSize,
+ },
+ {
name: "destination with unknown option skippable action",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
@@ -755,11 +783,6 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
pointer: header.IPv6FixedHeaderSize,
},
{
- name: "No next header",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
- shouldAccept: false,
- },
- {
name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
@@ -873,7 +896,13 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
Length: uint16(udpLength),
})
copy(u.Payload(), udpPayload)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+
+ dstAddr := tcpip.Address(addr2)
+ if test.multicast {
+ dstAddr = header.IPv6AllNodesMulticastAddress
+ }
+
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, dstAddr, uint16(udpLength))
sum = header.Checksum(udpPayload, sum)
u.SetChecksum(^u.CalculateChecksum(sum))
@@ -884,10 +913,6 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Serialize IPv6 fixed header.
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- dstAddr := tcpip.Address(addr2)
- if test.multicast {
- dstAddr = header.IPv6AllNodesMulticastAddress
- }
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
NextHeader: ipv6NextHdr,
@@ -982,9 +1007,10 @@ func TestReceiveIPv6Fragments(t *testing.T) {
udpPayload2Length = 128
// Used to test cases where the fragment blocks are not a multiple of
// the fragment block size of 8 (RFC 8200 section 4.5).
- udpPayload3Length = 127
- udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize
- fragmentExtHdrLen = 8
+ udpPayload3Length = 127
+ udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize
+ udpMaximumSizeMinus15 = header.UDPMaximumSize - 15
+ fragmentExtHdrLen = 8
// Note, not all routing extension headers will be 8 bytes but this test
// uses 8 byte routing extension headers for most sub tests.
routingExtHdrLen = 8
@@ -1328,14 +1354,14 @@ func TestReceiveIPv6Fragments(t *testing.T) {
dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+65520,
+ fragmentExtHdrLen+udpMaximumSizeMinus15,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload4Addr1ToAddr2[:65520],
+ ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15],
},
),
},
@@ -1344,14 +1370,17 @@ func TestReceiveIPv6Fragments(t *testing.T) {
dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-65520,
+ fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15,
[]buffer.View{
// Fragment extension header.
//
- // Fragment offset = 8190, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 255, 240, 0, 0, 0, 1}),
+ // Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
+ udpMaximumSizeMinus15 >> 8,
+ udpMaximumSizeMinus15 & 0xff,
+ 0, 0, 0, 1}),
- ipv6Payload4Addr1ToAddr2[65520:],
+ ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:],
},
),
},
@@ -1359,6 +1388,47 @@ func TestReceiveIPv6Fragments(t *testing.T) {
expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
},
{
+ name: "Two fragments with MF flag reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+udpMaximumSizeMinus15,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
+ udpMaximumSizeMinus15 >> 8,
+ (udpMaximumSizeMinus15 & 0xff) + 1,
+ 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
name: "Two fragments with per-fragment routing header with zero segments left",
fragments: []fragmentData{
{
@@ -2439,7 +2509,7 @@ func TestWriteStats(t *testing.T) {
test.setup(t, rt.Stack())
- nWritten, _ := writer.writePackets(&rt, pkts)
+ nWritten, _ := writer.writePackets(rt, pkts)
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
@@ -2456,7 +2526,7 @@ func TestWriteStats(t *testing.T) {
}
}
-func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
new file mode 100644
index 000000000..4c06b3f0c
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -0,0 +1,164 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // UnsolicitedReportIntervalMax is the maximum delay between sending
+ // unsolicited MLD reports.
+ //
+ // Obtained from RFC 2710 Section 7.10.
+ UnsolicitedReportIntervalMax = 10 * time.Second
+)
+
+// MLDOptions holds options for MLD.
+type MLDOptions struct {
+ // Enabled indicates whether MLD will be performed.
+ //
+ // When enabled, MLD may transmit MLD report and done messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // MLD packets.
+ Enabled bool
+}
+
+var _ ip.MulticastGroupProtocol = (*mldState)(nil)
+
+// mldState is the per-interface MLD state.
+//
+// mldState.init MUST be called to initialize the MLD state.
+type mldState struct {
+ // The IPv6 endpoint this mldState is for.
+ ep *endpoint
+
+ genericMulticastProtocol ip.GenericMulticastProtocolState
+}
+
+// SendReport implements ip.MulticastGroupProtocol.
+func (mld *mldState) SendReport(groupAddress tcpip.Address) *tcpip.Error {
+ return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport)
+}
+
+// SendLeave implements ip.MulticastGroupProtocol.
+func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+ return mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
+}
+
+// init sets up an mldState struct, and is required to be called before using
+// a new mldState.
+func (mld *mldState) init(ep *endpoint, opts MLDOptions) {
+ mld.ep = ep
+ mld.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{
+ Enabled: opts.Enabled,
+ Rand: ep.protocol.stack.Rand(),
+ Clock: ep.protocol.stack.Clock(),
+ Protocol: mld,
+ MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
+ AllNodesAddress: header.IPv6AllNodesMulticastAddress,
+ })
+}
+
+func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) {
+ mld.genericMulticastProtocol.HandleQuery(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay())
+}
+
+func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) {
+ mld.genericMulticastProtocol.HandleReport(mldHdr.MulticastAddress())
+}
+
+// joinGroup handles joining a new group and sending and scheduling the required
+// messages.
+//
+// If the group is already joined, returns tcpip.ErrDuplicateAddress.
+func (mld *mldState) joinGroup(groupAddress tcpip.Address) {
+ mld.genericMulticastProtocol.JoinGroup(groupAddress, !mld.ep.Enabled() /* dontInitialize */)
+}
+
+// isInGroup returns true if the specified group has been joined locally.
+func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool {
+ return mld.genericMulticastProtocol.IsLocallyJoined(groupAddress)
+}
+
+// leaveGroup handles removing the group from the membership map, cancels any
+// delay timers associated with that group, and sends the Done message, if
+// required.
+func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
+ // LeaveGroup returns false only if the group was not joined.
+ if mld.genericMulticastProtocol.LeaveGroup(groupAddress) {
+ return nil
+ }
+
+ return tcpip.ErrBadLocalAddress
+}
+
+// softLeaveAll leaves all groups from the perspective of MLD, but remains
+// joined locally.
+func (mld *mldState) softLeaveAll() {
+ mld.genericMulticastProtocol.MakeAllNonMember()
+}
+
+// initializeAll attemps to initialize the MLD state for each group that has
+// been joined locally.
+func (mld *mldState) initializeAll() {
+ mld.genericMulticastProtocol.InitializeGroups()
+}
+
+func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) *tcpip.Error {
+ sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
+ var mldStat *tcpip.StatCounter
+ switch mldType {
+ case header.ICMPv6MulticastListenerReport:
+ mldStat = sentStats.MulticastListenerReport
+ case header.ICMPv6MulticastListenerDone:
+ mldStat = sentStats.MulticastListenerDone
+ default:
+ panic(fmt.Sprintf("unrecognized mld type = %d", mldType))
+ }
+
+ icmp := header.ICMPv6(buffer.NewView(header.ICMPv6HeaderSize + header.MLDMinimumSize))
+ icmp.SetType(mldType)
+ header.MLD(icmp.MessageBody()).SetMulticastAddress(groupAddress)
+ // TODO(gvisor.dev/issue/4888): We should not use the unspecified address,
+ // rather we should select an appropriate local address.
+ localAddress := header.IPv6Any
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{}))
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()),
+ Data: buffer.View(icmp).ToVectorisedView(),
+ })
+
+ mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.MLDHopLimit,
+ })
+ // TODO(b/162198658): set the ROUTER_ALERT option when sending Host
+ // Membership Reports.
+ if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
+ sentStats.Dropped.Increment()
+ return err
+ }
+ mldStat.Increment()
+ return nil
+}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
new file mode 100644
index 000000000..5677bdd54
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -0,0 +1,90 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+)
+
+func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ MLD: ipv6.MLDOptions{
+ Enabled: true,
+ },
+ })},
+ })
+ e := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ // The stack will join an address's solicited node multicast address when
+ // an address is added. An MLD report message should be sent for the
+ // solicited-node group.
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, addr1, err)
+ }
+ {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("expected a report message to be sent")
+ }
+ snmc := header.SolicitedNodeAddr(addr1)
+ checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())),
+ checker.DstAddr(snmc),
+ // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
+ checker.TTL(1),
+ checker.MLD(header.ICMPv6MulticastListenerReport, header.MLDMinimumSize,
+ checker.MLDMaxRespDelay(0),
+ checker.MLDMulticastAddress(snmc),
+ ),
+ )
+ }
+
+ // The stack will leave an address's solicited node multicast address when
+ // an address is removed. An MLD done message should be sent for the
+ // solicited-node group.
+ if err := s.RemoveAddress(nicID, addr1); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err)
+ }
+ {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("expected a done message to be sent")
+ }
+ snmc := header.SolicitedNodeAddr(addr1)
+ checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(1),
+ checker.MLD(header.ICMPv6MulticastListenerDone, header.MLDMinimumSize,
+ checker.MLDMaxRespDelay(0),
+ checker.MLDMulticastAddress(snmc),
+ ),
+ )
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 40da011f8..8cb7d4dab 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -471,17 +471,8 @@ type ndpState struct {
// The default routers discovered through Router Advertisements.
defaultRouters map[tcpip.Address]defaultRouterState
- rtrSolicit struct {
- // The timer used to send the next router solicitation message.
- timer tcpip.Timer
-
- // Used to let the Router Solicitation timer know that it has been stopped.
- //
- // Must only be read from or written to while protected by the lock of
- // the IPv6 endpoint this ndpState is associated with. MUST be set when the
- // timer is set.
- done *bool
- }
+ // The job used to send the next router solicitation message.
+ rtrSolicitJob *tcpip.Job
// The on-link prefixes discovered through Router Advertisements' Prefix
// Information option.
@@ -507,7 +498,7 @@ type ndpState struct {
// to the DAD goroutine that DAD should stop.
type dadState struct {
// The DAD timer to send the next NS message, or resolve the address.
- timer tcpip.Timer
+ job *tcpip.Job
// Used to let the DAD timer know that it has been stopped.
//
@@ -648,96 +639,70 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// Consider DAD to have resolved even if no DAD messages were actually
// transmitted.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
}
return nil
}
- var done bool
- var timer tcpip.Timer
- // We initially start a timer to fire immediately because some of the DAD work
- // cannot be done while holding the IPv6 endpoint's lock. This is effectively
- // the same as starting a goroutine but we use a timer that fires immediately
- // so we can reset it for the next DAD iteration.
- timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() {
- ndp.ep.mu.Lock()
- defer ndp.ep.mu.Unlock()
-
- if done {
- // If we reach this point, it means that the DAD timer fired after
- // another goroutine already obtained the IPv6 endpoint lock and stopped
- // DAD before this function obtained the NIC lock. Simply return here and
- // do nothing further.
- return
- }
-
- if addressEndpoint.GetKind() != stack.PermanentTentative {
- // The endpoint should still be marked as tentative since we are still
- // performing DAD on it.
- panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
- }
+ state := dadState{
+ job: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ state, ok := ndp.dad[addr]
+ if !ok {
+ panic(fmt.Sprintf("ndpdad: DAD timer fired but missing state for %s on NIC(%d)", addr, ndp.ep.nic.ID()))
+ }
- dadDone := remaining == 0
-
- var err *tcpip.Error
- if !dadDone {
- // Use the unspecified address as the source address when performing DAD.
- addressEndpoint := ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
-
- // Do not hold the lock when sending packets which may be a long running
- // task or may block link address resolution. We know this is safe
- // because immediately after obtaining the lock again, we check if DAD
- // has been stopped before doing any work with the IPv6 endpoint. Note,
- // DAD would be stopped if the IPv6 endpoint was disabled or closed, or if
- // the address was removed.
- ndp.ep.mu.Unlock()
- err = ndp.sendDADPacket(addr, addressEndpoint)
- ndp.ep.mu.Lock()
- addressEndpoint.DecRef()
- }
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ // The endpoint should still be marked as tentative since we are still
+ // performing DAD on it.
+ panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
+ }
- if done {
- // If we reach this point, it means that DAD was stopped after we released
- // the IPv6 endpoint's read lock and before we obtained the write lock.
- return
- }
+ dadDone := remaining == 0
- if dadDone {
- // DAD has resolved.
- addressEndpoint.SetKind(stack.Permanent)
- } else if err == nil {
- // DAD is not done and we had no errors when sending the last NDP NS,
- // schedule the next DAD timer.
- remaining--
- timer.Reset(ndp.configs.RetransmitTimer)
- return
- }
+ var err *tcpip.Error
+ if !dadDone {
+ err = ndp.sendDADPacket(addr, addressEndpoint)
+ }
- // At this point we know that either DAD is done or we hit an error sending
- // the last NDP NS. Either way, clean up addr's DAD state and let the
- // integrator know DAD has completed.
- delete(ndp.dad, addr)
+ if dadDone {
+ // DAD has resolved.
+ addressEndpoint.SetKind(stack.Permanent)
+ } else if err == nil {
+ // DAD is not done and we had no errors when sending the last NDP NS,
+ // schedule the next DAD timer.
+ remaining--
+ state.job.Schedule(ndp.configs.RetransmitTimer)
+ return
+ }
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
- }
+ // At this point we know that either DAD is done or we hit an error
+ // sending the last NDP NS. Either way, clean up addr's DAD state and let
+ // the integrator know DAD has completed.
+ delete(ndp.dad, addr)
- // If DAD resolved for a stable SLAAC address, attempt generation of a
- // temporary SLAAC address.
- if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
- // Reset the generation attempts counter as we are starting the generation
- // of a new address for the SLAAC prefix.
- ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
- }
- })
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
+ }
- ndp.dad[addr] = dadState{
- timer: timer,
- done: &done,
+ // If DAD resolved for a stable SLAAC address, attempt generation of a
+ // temporary SLAAC address.
+ if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
+ }
+ }),
}
+ // We initially start a timer to fire immediately because some of the DAD work
+ // cannot be done while holding the IPv6 endpoint's lock. This is effectively
+ // the same as starting a goroutine but we use a timer that fires immediately
+ // so we can reset it for the next DAD iteration.
+ state.job.Schedule(0)
+ ndp.dad[addr] = state
+
return nil
}
@@ -745,55 +710,31 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// addr.
//
// addr must be a tentative IPv6 address on ndp's IPv6 endpoint.
-//
-// The IPv6 endpoint that ndp belongs to MUST NOT be locked.
func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- return err
- }
- defer r.Release()
-
- // Route should resolve immediately since snmc is a multicast address so a
- // remote link address can be calculated without a resolution process.
- if c, err := r.Resolve(nil); err != nil {
- // Do not consider the NIC being unknown or disabled as a fatal error.
- // Since this method is required to be called when the IPv6 endpoint is not
- // locked, the NIC could have been disabled or removed by another goroutine.
- if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState {
- return err
- }
-
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err))
- } else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID()))
- }
-
- icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
- icmpData.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmpData.NDPPayload())
+ icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
+ icmp.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(addr)
- icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, snmc, buffer.VectorisedView{}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: buffer.View(icmpData).ToVectorisedView(),
+ ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()),
+ Data: buffer.View(icmp).ToVectorisedView(),
})
- sent := r.Stats().ICMP.V6PacketsSent
- if err := r.WritePacket(nil,
- stack.NetworkHeaderParams{
- Protocol: header.ICMPv6ProtocolNumber,
- TTL: header.NDPHopLimit,
- }, pkt,
- ); err != nil {
+ sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
+ ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ })
+
+ if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.Dropped.Increment()
return err
}
sent.NeighborSolicit.Increment()
-
return nil
}
@@ -812,18 +753,11 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
return
}
- if dad.timer != nil {
- dad.timer.Stop()
- dad.timer = nil
-
- *dad.done = true
- dad.done = nil
- }
-
+ dad.job.Cancel()
delete(ndp.dad, addr)
// Let the integrator know DAD did not resolve.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil)
}
}
@@ -846,7 +780,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we
// only inform the dispatcher on configuration changes. We do nothing else
// with the information.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
var configuration DHCPv6ConfigurationFromNDPRA
switch {
case ra.ManagedAddrConfFlag():
@@ -903,20 +837,20 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() {
switch opt := opt.(type) {
case header.NDPRecursiveDNSServer:
- if ndp.ep.protocol.ndpDisp == nil {
+ if ndp.ep.protocol.options.NDPDisp == nil {
continue
}
addrs, _ := opt.Addresses()
- ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime())
+ ndp.ep.protocol.options.NDPDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime())
case header.NDPDNSSearchList:
- if ndp.ep.protocol.ndpDisp == nil {
+ if ndp.ep.protocol.options.NDPDisp == nil {
continue
}
domainNames, _ := opt.DomainNames()
- ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime())
+ ndp.ep.protocol.options.NDPDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime())
case header.NDPPrefixInformation:
prefix := opt.Subnet()
@@ -964,7 +898,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
delete(ndp.defaultRouters, ip)
// Let the integrator know a discovered default router is invalidated.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip)
}
}
@@ -976,7 +910,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
- ndpDisp := ndp.ep.protocol.ndpDisp
+ ndpDisp := ndp.ep.protocol.options.NDPDisp
if ndpDisp == nil {
return
}
@@ -1006,7 +940,7 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) {
- ndpDisp := ndp.ep.protocol.ndpDisp
+ ndpDisp := ndp.ep.protocol.options.NDPDisp
if ndpDisp == nil {
return
}
@@ -1047,7 +981,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
delete(ndp.onLinkPrefixes, prefix)
// Let the integrator know a discovered on-link prefix is invalidated.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix)
}
}
@@ -1225,7 +1159,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, configType stack.AddressConfigType, deprecated bool) stack.AddressEndpoint {
// Inform the integrator that we have a new SLAAC address.
- ndpDisp := ndp.ep.protocol.ndpDisp
+ ndpDisp := ndp.ep.protocol.options.NDPDisp
if ndpDisp == nil {
return nil
}
@@ -1272,7 +1206,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
}
dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures
- if oIID := ndp.ep.protocol.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ if oIID := ndp.ep.protocol.options.OpaqueIIDOpts; oIID.NICNameFromID != nil {
addrBytes = header.AppendOpaqueInterfaceIdentifier(
addrBytes[:header.IIDOffsetInIPv6Address],
prefix,
@@ -1676,7 +1610,7 @@ func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint
}
addressEndpoint.SetDeprecated(true)
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix())
}
}
@@ -1701,7 +1635,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefi
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) {
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
@@ -1761,7 +1695,7 @@ func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLA
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) {
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
@@ -1859,7 +1793,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) startSolicitingRouters() {
- if ndp.rtrSolicit.timer != nil {
+ if ndp.rtrSolicitJob != nil {
// We are already soliciting routers.
return
}
@@ -1876,56 +1810,14 @@ func (ndp *ndpState) startSolicitingRouters() {
delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
}
- var done bool
- ndp.rtrSolicit.done = &done
- ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() {
- ndp.ep.mu.Lock()
- if done {
- // If we reach this point, it means that the RS timer fired after another
- // goroutine already obtained the IPv6 endpoint lock and stopped
- // solicitations. Simply return here and do nothing further.
- ndp.ep.mu.Unlock()
- return
- }
-
+ ndp.rtrSolicitJob = ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
// As per RFC 4861 section 4.1, the source of the RS is an address assigned
// to the sending interface, or the unspecified address if no address is
// assigned to the sending interface.
- addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false)
- if addressEndpoint == nil {
- // Incase this ends up creating a new temporary address, we need to hold
- // onto the endpoint until a route is obtained. If we decrement the
- // reference count before obtaing a route, the address's resources would
- // be released and attempting to obtain a route after would fail. Once a
- // route is obtainted, it is safe to decrement the reference count since
- // obtaining a route increments the address's reference count.
- addressEndpoint = ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
- }
- ndp.ep.mu.Unlock()
-
- localAddr := addressEndpoint.AddressWithPrefix().Address
- r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */)
- addressEndpoint.DecRef()
- if err != nil {
- return
- }
- defer r.Release()
-
- // Route should resolve immediately since
- // header.IPv6AllRoutersMulticastAddress is a multicast address so a
- // remote link address can be calculated without a resolution process.
- if c, err := r.Resolve(nil); err != nil {
- // Do not consider the NIC being unknown or disabled as a fatal error.
- // Since this method is required to be called when the IPv6 endpoint is
- // not locked, the IPv6 endpoint could have been disabled or removed by
- // another goroutine.
- if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState {
- return
- }
-
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err))
- } else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID()))
+ localAddr := header.IPv6Any
+ if addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil {
+ localAddr = addressEndpoint.AddressWithPrefix().Address
+ addressEndpoint.DecRef()
}
// As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
@@ -1936,30 +1828,31 @@ func (ndp *ndpState) startSolicitingRouters() {
// TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
// LinkEndpoint.LinkAddress) before reaching this point.
var optsSerializer header.NDPOptionsSerializer
- if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) {
+ linkAddress := ndp.ep.nic.LinkAddress()
+ if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(linkAddress) {
optsSerializer = header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress),
+ header.NDPSourceLinkLayerAddressOption(linkAddress),
}
}
payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
icmpData := header.ICMPv6(buffer.NewView(payloadSize))
icmpData.SetType(header.ICMPv6RouterSolicit)
- rs := header.NDPRouterSolicit(icmpData.NDPPayload())
+ rs := header.NDPRouterSolicit(icmpData.MessageBody())
rs.Options().Serialize(optsSerializer)
- icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, localAddr, header.IPv6AllRoutersMulticastAddress, buffer.VectorisedView{}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()),
Data: buffer.View(icmpData).ToVectorisedView(),
})
- sent := r.Stats().ICMP.V6PacketsSent
- if err := r.WritePacket(nil,
- stack.NetworkHeaderParams{
- Protocol: header.ICMPv6ProtocolNumber,
- TTL: header.NDPHopLimit,
- }, pkt,
- ); err != nil {
+ sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
+ ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ })
+
+ if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.Dropped.Increment()
log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err)
// Don't send any more messages if we had an error.
@@ -1969,21 +1862,12 @@ func (ndp *ndpState) startSolicitingRouters() {
remaining--
}
- ndp.ep.mu.Lock()
- if done || remaining == 0 {
- ndp.rtrSolicit.timer = nil
- ndp.rtrSolicit.done = nil
- } else if ndp.rtrSolicit.timer != nil {
- // Note, we need to explicitly check to make sure that
- // the timer field is not nil because if it was nil but
- // we still reached this point, then we know the IPv6 endpoint
- // was requested to stop soliciting routers so we don't
- // need to send the next Router Solicitation message.
- ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval)
+ if remaining != 0 {
+ ndp.rtrSolicitJob.Schedule(ndp.configs.RtrSolicitationInterval)
}
- ndp.ep.mu.Unlock()
})
+ ndp.rtrSolicitJob.Schedule(delay)
}
// stopSolicitingRouters stops soliciting routers. If routers are not currently
@@ -1991,21 +1875,19 @@ func (ndp *ndpState) startSolicitingRouters() {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) stopSolicitingRouters() {
- if ndp.rtrSolicit.timer == nil {
+ if ndp.rtrSolicitJob == nil {
// Nothing to do.
return
}
- *ndp.rtrSolicit.done = true
- ndp.rtrSolicit.timer.Stop()
- ndp.rtrSolicit.timer = nil
- ndp.rtrSolicit.done = nil
+ ndp.rtrSolicitJob.Cancel()
+ ndp.rtrSolicitJob = nil
}
// initializeTempAddrState initializes state related to temporary SLAAC
// addresses.
func (ndp *ndpState) initializeTempAddrState() {
- header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID())
+ header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID())
if MaxDesyncFactor != 0 {
ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 37e8b1083..95c626bb8 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -205,7 +205,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(lladdr0)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -220,7 +220,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
@@ -311,7 +311,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(lladdr0)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -326,16 +326,16 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
- })
+ }))
neighbors, err := s.Neighbors(nicID)
if err != nil {
@@ -591,7 +591,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(nicAddr)
opts := ns.Options()
opts.Serialize(test.nsOpts)
@@ -606,7 +606,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
DstAddr: test.nsDst,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
@@ -650,8 +650,8 @@ func TestNeighorSolicitationResponse(t *testing.T) {
if p.Route.RemoteAddress != respNSDst {
t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst)
}
- if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(respNSDst); got != want {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
@@ -672,7 +672,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(test.nsSrc)
@@ -706,8 +706,8 @@ func TestNeighorSolicitationResponse(t *testing.T) {
if p.Route.RemoteAddress != test.naDst {
t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst)
}
- if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
+ if got := p.Route.RemoteLinkAddress(); got != test.naDstLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, test.naDstLinkAddr)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
@@ -777,7 +777,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns := header.NDPNeighborAdvert(pkt.MessageBody())
ns.SetTargetAddress(lladdr1)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -792,7 +792,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
@@ -890,7 +890,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns := header.NDPNeighborAdvert(pkt.MessageBody())
ns.SetTargetAddress(lladdr1)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -905,16 +905,16 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test
DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
- })
+ }))
neighbors, err := s.Neighbors(nicID)
if err != nil {
@@ -1122,7 +1122,7 @@ func TestNDPValidation(t *testing.T) {
s.SetForwarding(ProtocolNumber, true)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
routerOnly := stats.RouterOnlyPacketsDroppedByHost
typStat := typ.statCounter(stats)
@@ -1346,7 +1346,7 @@ func TestRouterAdvertValidation(t *testing.T) {
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
pkt.SetCode(test.code)
- copy(pkt.NDPPayload(), test.ndpPayload)
+ copy(pkt.MessageBody(), test.ndpPayload)
payloadLength := hdr.UsedLength()
pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -1358,7 +1358,7 @@ func TestRouterAdvertValidation(t *testing.T) {
DstAddr: header.IPv6AllNodesMulticastAddress,
})
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
rxRA := stats.RouterAdvert
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
new file mode 100644
index 000000000..95fb67986
--- /dev/null
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -0,0 +1,1069 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip_test
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+
+ ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03")
+ ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04")
+ ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05")
+ ipv6MulticastAddr1 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ ipv6MulticastAddr2 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04")
+ ipv6MulticastAddr3 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05")
+
+ igmpMembershipQuery = uint8(header.IGMPMembershipQuery)
+ igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport)
+ igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport)
+ igmpLeaveGroup = uint8(header.IGMPLeaveGroup)
+ mldQuery = uint8(header.ICMPv6MulticastListenerQuery)
+ mldReport = uint8(header.ICMPv6MulticastListenerReport)
+ mldDone = uint8(header.ICMPv6MulticastListenerDone)
+)
+
+var (
+ // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the
+ // NIC will wait before sending an unsolicited report after joining a
+ // multicast group, in deciseconds.
+ unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 {
+ const decisecond = time.Second / 10
+ if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 {
+ panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax))
+ }
+ return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond)
+ }()
+)
+
+// validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet
+// sent to the provided address with the passed fields set.
+func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) {
+ t.Helper()
+
+ payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ checker.IPv6(t, payload,
+ checker.DstAddr(remoteAddress),
+ // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
+ checker.TTL(1),
+ checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize,
+ checker.MLDMaxRespDelay(time.Duration(maxRespTime)*time.Millisecond),
+ checker.MLDMulticastAddress(groupAddress),
+ ),
+ )
+}
+
+// validateIGMPPacket checks that a passed PacketInfo is an IPv4 IGMP packet
+// sent to the provided address with the passed fields set.
+func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) {
+ t.Helper()
+
+ payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ checker.IPv4(t, payload,
+ checker.DstAddr(remoteAddress),
+ // TTL for an IGMP message must be 1 as per RFC 2236 section 2.
+ checker.TTL(1),
+ checker.IPv4RouterAlert(),
+ checker.IGMP(
+ checker.IGMPType(header.IGMPType(igmpType)),
+ checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)),
+ checker.IGMPGroupAddress(groupAddress),
+ ),
+ )
+}
+
+func createStack(t *testing.T, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
+ t.Helper()
+
+ // Create an endpoint of queue size 2, since no more than 2 packets are ever
+ // queued in the tests in this file.
+ e := channel.New(2, 1280, linkAddr)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocolWithOptions(ipv4.Options{
+ IGMP: ipv4.IGMPOptions{
+ Enabled: mgpEnabled,
+ },
+ }),
+ ipv6.NewProtocolWithOptions(ipv6.Options{
+ MLD: ipv6.MLDOptions{
+ Enabled: mgpEnabled,
+ },
+ }),
+ },
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ return e, s, clock
+}
+
+// createAndInjectIGMPPacket creates and injects an IGMP packet with the
+// specified fields.
+//
+// Note, the router alert option is not included in this packet.
+//
+// TODO(b/162198658): set the router alert option.
+func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address) {
+ buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize)
+
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(len(buf)),
+ TTL: header.IGMPTTL,
+ Protocol: uint8(header.IGMPProtocolNumber),
+ SrcAddr: header.IPv4Any,
+ DstAddr: header.IPv4AllSystems,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ igmp := header.IGMP(buf[header.IPv4MinimumSize:])
+ igmp.SetType(header.IGMPType(igmpType))
+ igmp.SetMaxRespTime(maxRespTime)
+ igmp.SetGroupAddress(groupAddress)
+ igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
+
+ e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+}
+
+// createAndInjectMLDPacket creates and injects an MLD packet with the
+// specified fields.
+//
+// Note, the router alert option is not included in this packet.
+//
+// TODO(b/162198658): set the router alert option.
+func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address) {
+ icmpSize := header.ICMPv6HeaderSize + header.MLDMinimumSize
+ buf := buffer.NewView(header.IPv6MinimumSize + icmpSize)
+
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(icmpSize),
+ HopLimit: header.MLDHopLimit,
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ SrcAddr: header.IPv4Any,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
+
+ icmp := header.ICMPv6(buf[header.IPv6MinimumSize:])
+ icmp.SetType(header.ICMPv6Type(mldType))
+ mld := header.MLD(icmp.MessageBody())
+ mld.SetMaximumResponseDelay(uint16(maxRespDelay))
+ mld.SetMulticastAddress(groupAddress)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+
+ e.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+}
+
+// TestMGPDisabled tests that the multicast group protocol is not enabled by
+// default.
+func TestMGPDisabled(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
+ rxQuery func(*channel.Endpoint)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ rxQuery: func(e *channel.Endpoint) {
+ createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any)
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
+ },
+ rxQuery: func(e *channel.Endpoint) {
+ createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, false)
+
+ // This NIC may join multicast groups when it is enabled but since MGP is
+ // disabled, no reports should be sent.
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportState.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt)
+ }
+
+ // Test joining a specific group explicitly and verify that no reports are
+ // sent.
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportState.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %#v", p.Pkt)
+ }
+
+ // Inject a general query message. This should only trigger a report to be
+ // sent if the MGP was enabled.
+ test.rxQuery(e)
+ if got := test.receivedQueryStat(s).Value(); got != 1 {
+ t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt)
+ }
+ })
+ }
+}
+
+func TestMGPReceiveCounters(t *testing.T) {
+ tests := []struct {
+ name string
+ headerType uint8
+ maxRespTime byte
+ groupAddress tcpip.Address
+ statCounter func(*stack.Stack) *tcpip.StatCounter
+ rxMGPkt func(*channel.Endpoint, byte, byte, tcpip.Address)
+ }{
+ {
+ name: "IGMP Membership Query",
+ headerType: igmpMembershipQuery,
+ maxRespTime: unsolicitedIGMPReportIntervalMaxTenthSec,
+ groupAddress: header.IPv4Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "IGMPv1 Membership Report",
+ headerType: igmpv1MembershipReport,
+ maxRespTime: 0,
+ groupAddress: header.IPv4AllSystems,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.V1MembershipReport
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "IGMPv2 Membership Report",
+ headerType: igmpv2MembershipReport,
+ maxRespTime: 0,
+ groupAddress: header.IPv4AllSystems,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.V2MembershipReport
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "IGMP Leave Group",
+ headerType: igmpLeaveGroup,
+ maxRespTime: 0,
+ groupAddress: header.IPv4AllRoutersGroup,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.LeaveGroup
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "MLD Query",
+ headerType: mldQuery,
+ maxRespTime: 0,
+ groupAddress: header.IPv6Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
+ },
+ rxMGPkt: createAndInjectMLDPacket,
+ },
+ {
+ name: "MLD Report",
+ headerType: mldReport,
+ maxRespTime: 0,
+ groupAddress: header.IPv6Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerReport
+ },
+ rxMGPkt: createAndInjectMLDPacket,
+ },
+ {
+ name: "MLD Done",
+ headerType: mldDone,
+ maxRespTime: 0,
+ groupAddress: header.IPv6Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerDone
+ },
+ rxMGPkt: createAndInjectMLDPacket,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, _ := createStack(t, true)
+
+ test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress)
+ if got := test.statCounter(s).Value(); got != 1 {
+ t.Fatalf("got %s received = %d, want = 1", test.name, got)
+ }
+ })
+ }
+}
+
+// TestMGPJoinGroup tests that when explicitly joining a multicast group, the
+// stack schedules and sends correct Membership Reports.
+func TestMGPJoinGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ maxUnsolicitedResponseDelay time.Duration
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
+ validateReport func(*testing.T, channel.PacketInfo)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, true)
+
+ // Test joining a specific address explicitly and verify a Report is sent
+ // immediately.
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != 1 {
+ t.Errorf("got sentReportState.Value() = %d, want = 1", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Verify the second report is sent by the maximum unsolicited response
+ // interval.
+ p, ok := e.Read()
+ if ok {
+ t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt)
+ }
+ clock.Advance(test.maxUnsolicitedResponseDelay)
+ if got := sentReportStat.Value(); got != 2 {
+ t.Errorf("got sentReportState.Value() = %d, want = 2", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
+
+// TestMGPLeaveGroup tests that when leaving a previously joined multicast
+// group the stack sends a leave/done message.
+func TestMGPLeaveGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
+ validateReport func(*testing.T, channel.PacketInfo)
+ validateLeave func(*testing.T, channel.PacketInfo)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.LeaveGroup
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1)
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, true)
+
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ if got := test.sentReportStat(s).Value(); got != 1 {
+ t.Errorf("got sentReportStat(_).Value() = %d, want = 1", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leaving the group should trigger an leave/done message to be sent.
+ if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
+ }
+ if got := test.sentLeaveStat(s).Value(); got != 1 {
+ t.Fatalf("got sentLeaveStat(_).Value() = %d, want = 1", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a leave message to be sent")
+ } else {
+ test.validateLeave(t, p)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
+
+// TestMGPQueryMessages tests that a report is sent in response to query
+// messages.
+func TestMGPQueryMessages(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ maxUnsolicitedResponseDelay time.Duration
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
+ rxQuery func(*channel.Endpoint, uint8, tcpip.Address)
+ validateReport func(*testing.T, channel.PacketInfo)
+ maxRespTimeToDuration func(uint8) time.Duration
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
+ createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ maxRespTimeToDuration: header.DecisecondToDuration,
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
+ },
+ rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
+ createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+ },
+ maxRespTimeToDuration: func(d uint8) time.Duration {
+ return time.Duration(d) * time.Millisecond
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ subTests := []struct {
+ name string
+ multicastAddr tcpip.Address
+ expectReport bool
+ }{
+ {
+ name: "Unspecified",
+ multicastAddr: tcpip.Address(strings.Repeat("\x00", len(test.multicastAddr))),
+ expectReport: true,
+ },
+ {
+ name: "Specified",
+ multicastAddr: test.multicastAddr,
+ expectReport: true,
+ },
+ {
+ name: "Specified other address",
+ multicastAddr: func() tcpip.Address {
+ addrBytes := []byte(test.multicastAddr)
+ addrBytes[len(addrBytes)-1]++
+ return tcpip.Address(addrBytes)
+ }(),
+ expectReport: false,
+ },
+ }
+
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ e, s, clock := createStack(t, true)
+
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ sentReportStat := test.sentReportStat(s)
+ for i := uint64(1); i <= 2; i++ {
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != i {
+ t.Errorf("(i=%d) got sentReportState.Value() = %d, want = %d", i, got, i)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatalf("expected %d-th report message to be sent", i)
+ } else {
+ test.validateReport(t, p)
+ }
+ clock.Advance(test.maxUnsolicitedResponseDelay)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Should not send any more packets until a query.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+
+ // Receive a query message which should trigger a report to be sent at
+ // some time before the maximum response time if the report is
+ // targeted at the host.
+ const maxRespTime = 100
+ test.rxQuery(e, maxRespTime, subTest.multicastAddr)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p.Pkt)
+ }
+
+ if subTest.expectReport {
+ clock.Advance(test.maxRespTimeToDuration(maxRespTime))
+ if got := sentReportStat.Value(); got != 3 {
+ t.Errorf("got sentReportState.Value() = %d, want = 3", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+ })
+ }
+}
+
+// TestMGPQueryMessages tests that no further reports or leave/done messages
+// are sent after receiving a report.
+func TestMGPReportMessages(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
+ rxReport func(*channel.Endpoint)
+ validateReport func(*testing.T, channel.PacketInfo)
+ maxRespTimeToDuration func(uint8) time.Duration
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.LeaveGroup
+ },
+ rxReport: func(e *channel.Endpoint) {
+ createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ maxRespTimeToDuration: header.DecisecondToDuration,
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ },
+ rxReport: func(e *channel.Endpoint) {
+ createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+ },
+ maxRespTimeToDuration: func(d uint8) time.Duration {
+ return time.Duration(d) * time.Millisecond
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, true)
+
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != 1 {
+ t.Errorf("got sentReportStat.Value() = %d, want = 1", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Receiving a report for a group we joined should cancel any further
+ // reports.
+ test.rxReport(e)
+ clock.Advance(time.Hour)
+ if got := sentReportStat.Value(); got != 1 {
+ t.Errorf("got sentReportStat.Value() = %d, want = 1", got)
+ }
+ if p, ok := e.Read(); ok {
+ t.Errorf("sent unexpected packet = %#v", p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leaving a group after getting a report should not send a leave/done
+ // message.
+ if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
+ }
+ clock.Advance(time.Hour)
+ if got := test.sentLeaveStat(s).Value(); got != 0 {
+ t.Fatalf("got sentLeaveStat(_).Value() = %d, want = 0", got)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
+
+func TestMGPWithNICLifecycle(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddrs []tcpip.Address
+ finalMulticastAddr tcpip.Address
+ maxUnsolicitedResponseDelay time.Duration
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
+ validateReport func(*testing.T, channel.PacketInfo, tcpip.Address)
+ validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address)
+ getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddrs: []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2},
+ finalMulticastAddr: ipv4MulticastAddr3,
+ maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.LeaveGroup
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, addr)
+ },
+ getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
+ t.Helper()
+
+ ipv4 := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ if got := tcpip.TransportProtocolNumber(ipv4.Protocol()); got != header.IGMPProtocolNumber {
+ t.Fatalf("got ipv4.Protocol() = %d, want = %d", got, header.IGMPProtocolNumber)
+ }
+ addr := header.IGMP(ipv4.Payload()).GroupAddress()
+ s, ok := seen[addr]
+ if !ok {
+ t.Fatalf("unexpectedly got a packet for group %s", addr)
+ }
+ if s {
+ t.Fatalf("already saw packet for group %s", addr)
+ }
+ seen[addr] = true
+ return addr
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddrs: []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2},
+ finalMulticastAddr: ipv6MulticastAddr3,
+ maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateMLDPacket(t, p, addr, mldReport, 0, addr)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr)
+ },
+ getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
+ t.Helper()
+
+ ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ if got := tcpip.TransportProtocolNumber(ipv6.NextHeader()); got != header.ICMPv6ProtocolNumber {
+ t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber)
+ }
+ icmpv6 := header.ICMPv6(ipv6.Payload())
+ if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone {
+ t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone)
+ }
+ addr := header.MLD(icmpv6.MessageBody()).MulticastAddress()
+ s, ok := seen[addr]
+ if !ok {
+ t.Fatalf("unexpectedly got a packet for group %s", addr)
+ }
+ if s {
+ t.Fatalf("already saw packet for group %s", addr)
+ }
+ seen[addr] = true
+ return addr
+
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, true)
+
+ sentReportStat := test.sentReportStat(s)
+ var reportCounter uint64
+ for _, a := range test.multicastAddrs {
+ if err := s.JoinGroup(test.protoNum, nicID, a); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err)
+ }
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatalf("expected a report message to be sent for %s", a)
+ } else {
+ test.validateReport(t, p, a)
+ }
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leave messages should be sent for the joined groups when the NIC is
+ // disabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ sentLeaveStat := test.sentLeaveStat(s)
+ leaveCounter := uint64(len(test.multicastAddrs))
+ if got := sentLeaveStat.Value(); got != leaveCounter {
+ t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
+ }
+ {
+ seen := make(map[tcpip.Address]bool)
+ for _, a := range test.multicastAddrs {
+ seen[a] = false
+ }
+
+ for i, _ := range test.multicastAddrs {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected (%d-th) leave message to be sent", i)
+ }
+
+ test.validateLeave(t, p, test.getAndCheckGroupAddress(t, seen, p))
+ }
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Reports should be sent for the joined groups when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("EnableNIC(%d): %s", nicID, err)
+ }
+ reportCounter += uint64(len(test.multicastAddrs))
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ {
+ seen := make(map[tcpip.Address]bool)
+ for _, a := range test.multicastAddrs {
+ seen[a] = false
+ }
+
+ for i, _ := range test.multicastAddrs {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected (%d-th) report message to be sent", i)
+ }
+
+ test.validateReport(t, p, test.getAndCheckGroupAddress(t, seen, p))
+ }
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Joining/leaving a group while disabled should not send any messages.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ leaveCounter += uint64(len(test.multicastAddrs))
+ if got := sentLeaveStat.Value(); got != leaveCounter {
+ t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
+ }
+ for i, _ := range test.multicastAddrs {
+ if _, ok := e.Read(); !ok {
+ t.Fatalf("expected (%d-th) leave message to be sent", i)
+ }
+ }
+ for _, a := range test.multicastAddrs {
+ if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil {
+ t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err)
+ }
+ if got := sentLeaveStat.Value(); got != leaveCounter {
+ t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p.Pkt)
+ }
+ }
+ if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err)
+ }
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p.Pkt)
+ }
+
+ // A report should only be sent for the group we last joined after
+ // enabling the NIC since the original groups were all left.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("EnableNIC(%d): %s", nicID, err)
+ }
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p, test.finalMulticastAddr)
+ }
+
+ clock.Advance(test.maxUnsolicitedResponseDelay)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportState.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p, test.finalMulticastAddr)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
index 7cc52985e..5c3363759 100644
--- a/pkg/tcpip/network/testutil/testutil.go
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -85,21 +85,6 @@ func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts st
return n, nil
}
-// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
-func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- if ep.allowPackets == 0 {
- return ep.err
- }
- ep.allowPackets--
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
- ep.WrittenPackets = append(ep.WrittenPackets, pkt)
-
- return nil
-}
-
// Attach implements LinkEndpoint.Attach.
func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {}