summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network/internal
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-02-09 11:48:08 -0800
committergVisor bot <gvisor-bot@google.com>2021-02-09 11:52:31 -0800
commit18e993eb4f2e6db829acfb5e8725f7d12f73ab67 (patch)
tree2ac310ae5790668143cc3ee9f7ffb086d6431642 /pkg/tcpip/network/internal
parentd0c0549e607699e0186065ad9186431f12260487 (diff)
Move network internal code to internal package
Utilities written to be common across IPv4/IPv6 are not planned to be available for public use. https://golang.org/doc/go1.4#internalpackages PiperOrigin-RevId: 356554862
Diffstat (limited to 'pkg/tcpip/network/internal')
-rw-r--r--pkg/tcpip/network/internal/fragmentation/BUILD54
-rw-r--r--pkg/tcpip/network/internal/fragmentation/fragmentation.go339
-rw-r--r--pkg/tcpip/network/internal/fragmentation/fragmentation_test.go638
-rw-r--r--pkg/tcpip/network/internal/fragmentation/reassembler.go182
-rw-r--r--pkg/tcpip/network/internal/fragmentation/reassembler_test.go233
-rw-r--r--pkg/tcpip/network/internal/ip/BUILD17
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol.go696
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go805
-rw-r--r--pkg/tcpip/network/internal/ip/stats.go100
-rw-r--r--pkg/tcpip/network/internal/testutil/BUILD23
-rw-r--r--pkg/tcpip/network/internal/testutil/testutil.go197
-rw-r--r--pkg/tcpip/network/internal/testutil/testutil_unsafe.go26
12 files changed, 3307 insertions, 3 deletions
diff --git a/pkg/tcpip/network/internal/fragmentation/BUILD b/pkg/tcpip/network/internal/fragmentation/BUILD
new file mode 100644
index 000000000..274f09092
--- /dev/null
+++ b/pkg/tcpip/network/internal/fragmentation/BUILD
@@ -0,0 +1,54 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "reassembler_list",
+ out = "reassembler_list.go",
+ package = "fragmentation",
+ prefix = "reassembler",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*reassembler",
+ "Linker": "*reassembler",
+ },
+)
+
+go_library(
+ name = "fragmentation",
+ srcs = [
+ "fragmentation.go",
+ "reassembler.go",
+ "reassembler_list.go",
+ ],
+ visibility = [
+ "//pkg/tcpip/network/ipv4:__pkg__",
+ "//pkg/tcpip/network/ipv6:__pkg__",
+ ],
+ deps = [
+ "//pkg/log",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "fragmentation_test",
+ size = "small",
+ srcs = [
+ "fragmentation_test.go",
+ "reassembler_test.go",
+ ],
+ library = ":fragmentation",
+ deps = [
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/faketime",
+ "//pkg/tcpip/network/internal/testutil",
+ "//pkg/tcpip/stack",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation.go b/pkg/tcpip/network/internal/fragmentation/fragmentation.go
new file mode 100644
index 000000000..243738951
--- /dev/null
+++ b/pkg/tcpip/network/internal/fragmentation/fragmentation.go
@@ -0,0 +1,339 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fragmentation contains the implementation of IP fragmentation.
+// It is based on RFC 791, RFC 815 and RFC 8200.
+package fragmentation
+
+import (
+ "errors"
+ "fmt"
+ "log"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // HighFragThreshold is the threshold at which we start trimming old
+ // fragmented packets. Linux uses a default value of 4 MB. See
+ // net.ipv4.ipfrag_high_thresh for more information.
+ HighFragThreshold = 4 << 20 // 4MB
+
+ // LowFragThreshold is the threshold we reach to when we start dropping
+ // older fragmented packets. It's important that we keep enough room for newer
+ // packets to be re-assembled. Hence, this needs to be lower than
+ // HighFragThreshold enough. Linux uses a default value of 3 MB. See
+ // net.ipv4.ipfrag_low_thresh for more information.
+ LowFragThreshold = 3 << 20 // 3MB
+
+ // minBlockSize is the minimum block size for fragments.
+ minBlockSize = 1
+)
+
+var (
+ // ErrInvalidArgs indicates to the caller that an invalid argument was
+ // provided.
+ ErrInvalidArgs = errors.New("invalid args")
+
+ // ErrFragmentOverlap indicates that, during reassembly, a fragment overlaps
+ // with another one.
+ ErrFragmentOverlap = errors.New("overlapping fragments")
+
+ // ErrFragmentConflict indicates that, during reassembly, some fragments are
+ // in conflict with one another.
+ ErrFragmentConflict = errors.New("conflicting fragments")
+)
+
+// FragmentID is the identifier for a fragment.
+type FragmentID struct {
+ // Source is the source address of the fragment.
+ Source tcpip.Address
+
+ // Destination is the destination address of the fragment.
+ Destination tcpip.Address
+
+ // ID is the identification value of the fragment.
+ //
+ // This is a uint32 because IPv6 uses a 32-bit identification value.
+ ID uint32
+
+ // The protocol for the packet.
+ Protocol uint8
+}
+
+// Fragmentation is the main structure that other modules
+// of the stack should use to implement IP Fragmentation.
+type Fragmentation struct {
+ mu sync.Mutex
+ highLimit int
+ lowLimit int
+ reassemblers map[FragmentID]*reassembler
+ rList reassemblerList
+ memSize int
+ timeout time.Duration
+ blockSize uint16
+ clock tcpip.Clock
+ releaseJob *tcpip.Job
+ timeoutHandler TimeoutHandler
+}
+
+// TimeoutHandler is consulted if a packet reassembly has timed out.
+type TimeoutHandler interface {
+ // OnReassemblyTimeout will be called with the first fragment (or nil, if the
+ // first fragment has not been received) of a packet whose reassembly has
+ // timed out.
+ OnReassemblyTimeout(pkt *stack.PacketBuffer)
+}
+
+// NewFragmentation creates a new Fragmentation.
+//
+// blockSize specifies the fragment block size, in bytes.
+//
+// highMemoryLimit specifies the limit on the memory consumed
+// by the fragments stored by Fragmentation (overhead of internal data-structures
+// is not accounted). Fragments are dropped when the limit is reached.
+//
+// lowMemoryLimit specifies the limit on which we will reach by dropping
+// fragments after reaching highMemoryLimit.
+//
+// reassemblingTimeout specifies the maximum time allowed to reassemble a packet.
+// Fragments are lazily evicted only when a new a packet with an
+// already existing fragmentation-id arrives after the timeout.
+func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock, timeoutHandler TimeoutHandler) *Fragmentation {
+ if lowMemoryLimit >= highMemoryLimit {
+ lowMemoryLimit = highMemoryLimit
+ }
+
+ if lowMemoryLimit < 0 {
+ lowMemoryLimit = 0
+ }
+
+ if blockSize < minBlockSize {
+ blockSize = minBlockSize
+ }
+
+ f := &Fragmentation{
+ reassemblers: make(map[FragmentID]*reassembler),
+ highLimit: highMemoryLimit,
+ lowLimit: lowMemoryLimit,
+ timeout: reassemblingTimeout,
+ blockSize: blockSize,
+ clock: clock,
+ timeoutHandler: timeoutHandler,
+ }
+ f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked)
+
+ return f
+}
+
+// Process processes an incoming fragment belonging to an ID and returns a
+// complete packet and its protocol number when all the packets belonging to
+// that ID have been received.
+//
+// [first, last] is the range of the fragment bytes.
+//
+// first must be a multiple of the block size f is configured with. The size
+// of the fragment data must be a multiple of the block size, unless there are
+// no fragments following this fragment (more set to false).
+//
+// proto is the protocol number marked in the fragment being processed. It has
+// to be given here outside of the FragmentID struct because IPv6 should not use
+// the protocol to identify a fragment.
+func (f *Fragmentation) Process(
+ id FragmentID, first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (
+ *stack.PacketBuffer, uint8, bool, error) {
+ if first > last {
+ return nil, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
+ }
+
+ if first%f.blockSize != 0 {
+ return nil, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs)
+ }
+
+ fragmentSize := last - first + 1
+ if more && fragmentSize%f.blockSize != 0 {
+ return nil, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs)
+ }
+
+ if l := pkt.Data.Size(); l != int(fragmentSize) {
+ return nil, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
+ }
+
+ f.mu.Lock()
+ r, ok := f.reassemblers[id]
+ if !ok {
+ r = newReassembler(id, f.clock)
+ f.reassemblers[id] = r
+ wasEmpty := f.rList.Empty()
+ f.rList.PushFront(r)
+ if wasEmpty {
+ // If we have just pushed a first reassembler into an empty list, we
+ // should kickstart the release job. The release job will keep
+ // rescheduling itself until the list becomes empty.
+ f.releaseReassemblersLocked()
+ }
+ }
+ f.mu.Unlock()
+
+ resPkt, firstFragmentProto, done, memConsumed, err := r.process(first, last, more, proto, pkt)
+ if err != nil {
+ // We probably got an invalid sequence of fragments. Just
+ // discard the reassembler and move on.
+ f.mu.Lock()
+ f.release(r, false /* timedOut */)
+ f.mu.Unlock()
+ return nil, 0, false, fmt.Errorf("fragmentation processing error: %w", err)
+ }
+ f.mu.Lock()
+ f.memSize += memConsumed
+ if done {
+ f.release(r, false /* timedOut */)
+ }
+ // Evict reassemblers if we are consuming more memory than highLimit until
+ // we reach lowLimit.
+ if f.memSize > f.highLimit {
+ for f.memSize > f.lowLimit {
+ tail := f.rList.Back()
+ if tail == nil {
+ break
+ }
+ f.release(tail, false /* timedOut */)
+ }
+ }
+ f.mu.Unlock()
+ return resPkt, firstFragmentProto, done, nil
+}
+
+func (f *Fragmentation) release(r *reassembler, timedOut bool) {
+ // Before releasing a fragment we need to check if r is already marked as done.
+ // Otherwise, we would delete it twice.
+ if r.checkDoneOrMark() {
+ return
+ }
+
+ delete(f.reassemblers, r.id)
+ f.rList.Remove(r)
+ f.memSize -= r.memSize
+ if f.memSize < 0 {
+ log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.memSize)
+ f.memSize = 0
+ }
+
+ if h := f.timeoutHandler; timedOut && h != nil {
+ h.OnReassemblyTimeout(r.pkt)
+ }
+}
+
+// releaseReassemblersLocked releases already-expired reassemblers, then
+// schedules the job to call back itself for the remaining reassemblers if
+// any. This function must be called with f.mu locked.
+func (f *Fragmentation) releaseReassemblersLocked() {
+ now := f.clock.NowMonotonic()
+ for {
+ // The reassembler at the end of the list is the oldest.
+ r := f.rList.Back()
+ if r == nil {
+ // The list is empty.
+ break
+ }
+ elapsed := time.Duration(now-r.creationTime) * time.Nanosecond
+ if f.timeout > elapsed {
+ // If the oldest reassembler has not expired, schedule the release
+ // job so that this function is called back when it has expired.
+ f.releaseJob.Schedule(f.timeout - elapsed)
+ break
+ }
+ // If the oldest reassembler has already expired, release it.
+ f.release(r, true /* timedOut*/)
+ }
+}
+
+// PacketFragmenter is the book-keeping struct for packet fragmentation.
+type PacketFragmenter struct {
+ transportHeader buffer.View
+ data buffer.VectorisedView
+ reserve int
+ fragmentPayloadLen int
+ fragmentCount int
+ currentFragment int
+ fragmentOffset int
+}
+
+// MakePacketFragmenter prepares the struct needed for packet fragmentation.
+//
+// pkt is the packet to be fragmented.
+//
+// fragmentPayloadLen is the maximum number of bytes of fragmentable data a fragment can
+// have.
+//
+// reserve is the number of bytes that should be reserved for the headers in
+// each generated fragment.
+func MakePacketFragmenter(pkt *stack.PacketBuffer, fragmentPayloadLen uint32, reserve int) PacketFragmenter {
+ // As per RFC 8200 Section 4.5, some IPv6 extension headers should not be
+ // repeated in each fragment. However we do not currently support any header
+ // of that kind yet, so the following computation is valid for both IPv4 and
+ // IPv6.
+ // TODO(gvisor.dev/issue/3912): Once Authentication or ESP Headers are
+ // supported for outbound packets, the fragmentable data should not include
+ // these headers.
+ var fragmentableData buffer.VectorisedView
+ fragmentableData.AppendView(pkt.TransportHeader().View())
+ fragmentableData.Append(pkt.Data)
+ fragmentCount := (uint32(fragmentableData.Size()) + fragmentPayloadLen - 1) / fragmentPayloadLen
+
+ return PacketFragmenter{
+ data: fragmentableData,
+ reserve: reserve,
+ fragmentPayloadLen: int(fragmentPayloadLen),
+ fragmentCount: int(fragmentCount),
+ }
+}
+
+// BuildNextFragment returns a packet with the payload of the next fragment,
+// along with the fragment's offset, the number of bytes copied and a boolean
+// indicating if there are more fragments left or not. If this function is
+// called again after it indicated that no more fragments were left, it will
+// panic.
+//
+// Note that the returned packet will not have its network and link headers
+// populated, but space for them will be reserved. The transport header will be
+// stored in the packet's data.
+func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, bool) {
+ if pf.currentFragment >= pf.fragmentCount {
+ panic("BuildNextFragment should not be called again after the last fragment was returned")
+ }
+
+ fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: pf.reserve,
+ })
+
+ // Copy data for the fragment.
+ copied := pf.data.ReadToVV(&fragPkt.Data, pf.fragmentPayloadLen)
+
+ offset := pf.fragmentOffset
+ pf.fragmentOffset += copied
+ pf.currentFragment++
+ more := pf.currentFragment != pf.fragmentCount
+
+ return fragPkt, offset, copied, more
+}
+
+// RemainingFragmentCount returns the number of fragments left to be built.
+func (pf *PacketFragmenter) RemainingFragmentCount() int {
+ return pf.fragmentCount - pf.currentFragment
+}
diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
new file mode 100644
index 000000000..47ea3173e
--- /dev/null
+++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
@@ -0,0 +1,638 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fragmentation
+
+import (
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// reassembleTimeout is dummy timeout used for testing, where the clock never
+// advances.
+const reassembleTimeout = 1
+
+// vv is a helper to build VectorisedView from different strings.
+func vv(size int, pieces ...string) buffer.VectorisedView {
+ views := make([]buffer.View, len(pieces))
+ for i, p := range pieces {
+ views[i] = []byte(p)
+ }
+
+ return buffer.NewVectorisedView(size, views)
+}
+
+func pkt(size int, pieces ...string) *stack.PacketBuffer {
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv(size, pieces...),
+ })
+}
+
+type processInput struct {
+ id FragmentID
+ first uint16
+ last uint16
+ more bool
+ proto uint8
+ pkt *stack.PacketBuffer
+}
+
+type processOutput struct {
+ vv buffer.VectorisedView
+ proto uint8
+ done bool
+}
+
+var processTestCases = []struct {
+ comment string
+ in []processInput
+ out []processOutput
+}{
+ {
+ comment: "One ID",
+ in: []processInput{
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")},
+ },
+ out: []processOutput{
+ {vv: buffer.VectorisedView{}, done: false},
+ {vv: vv(4, "01", "23"), done: true},
+ },
+ },
+ {
+ comment: "Next Header protocol mismatch",
+ in: []processInput{
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, pkt: pkt(2, "01")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, pkt: pkt(2, "23")},
+ },
+ out: []processOutput{
+ {vv: buffer.VectorisedView{}, done: false},
+ {vv: vv(4, "01", "23"), proto: 6, done: true},
+ },
+ },
+ {
+ comment: "Two IDs",
+ in: []processInput{
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")},
+ {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, pkt: pkt(2, "ab")},
+ {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, pkt: pkt(2, "cd")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")},
+ },
+ out: []processOutput{
+ {vv: buffer.VectorisedView{}, done: false},
+ {vv: buffer.VectorisedView{}, done: false},
+ {vv: vv(4, "ab", "cd"), done: true},
+ {vv: vv(4, "01", "23"), done: true},
+ },
+ },
+}
+
+func TestFragmentationProcess(t *testing.T) {
+ for _, c := range processTestCases {
+ t.Run(c.comment, func(t *testing.T) {
+ f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil)
+ firstFragmentProto := c.in[0].proto
+ for i, in := range c.in {
+ resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt)
+ if err != nil {
+ t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s",
+ in.id, in.first, in.last, in.more, in.proto, in.pkt, err)
+ }
+ if done != c.out[i].done {
+ t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)",
+ in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done)
+ }
+ if c.out[i].done {
+ if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data.ToOwnedView()); diff != "" {
+ t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s",
+ in.id, in.first, in.last, in.more, in.proto, in.pkt, diff)
+ }
+ if firstFragmentProto != proto {
+ t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)",
+ in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto)
+ }
+ if _, ok := f.reassemblers[in.id]; ok {
+ t.Errorf("Process(%d) did not remove buffer from reassemblers", i)
+ }
+ for n := f.rList.Front(); n != nil; n = n.Next() {
+ if n.id == in.id {
+ t.Errorf("Process(%d) did not remove buffer from rList", i)
+ }
+ }
+ }
+ }
+ })
+ }
+}
+
+func TestReassemblingTimeout(t *testing.T) {
+ const (
+ reassemblyTimeout = time.Millisecond
+ protocol = 0xff
+ )
+
+ type fragment struct {
+ first uint16
+ last uint16
+ more bool
+ data string
+ }
+
+ type event struct {
+ // name is a nickname of this event.
+ name string
+
+ // clockAdvance is a duration to advance the clock. The clock advances
+ // before a fragment specified in the fragment field is processed.
+ clockAdvance time.Duration
+
+ // fragment is a fragment to process. This can be nil if there is no
+ // fragment to process.
+ fragment *fragment
+
+ // expectDone is true if the fragmentation instance should report the
+ // reassembly is done after the fragment is processd.
+ expectDone bool
+
+ // memSizeAfterEvent is the expected memory size of the fragmentation
+ // instance after the event.
+ memSizeAfterEvent int
+ }
+
+ memSizeOfFrags := func(frags ...*fragment) int {
+ var size int
+ for _, frag := range frags {
+ size += pkt(len(frag.data), frag.data).MemSize()
+ }
+ return size
+ }
+
+ half1 := &fragment{first: 0, last: 0, more: true, data: "0"}
+ half2 := &fragment{first: 1, last: 1, more: false, data: "1"}
+
+ tests := []struct {
+ name string
+ events []event
+ }{
+ {
+ name: "half1 and half2 are reassembled successfully",
+ events: []event{
+ {
+ name: "half1",
+ fragment: half1,
+ expectDone: false,
+ memSizeAfterEvent: memSizeOfFrags(half1),
+ },
+ {
+ name: "half2",
+ fragment: half2,
+ expectDone: true,
+ memSizeAfterEvent: 0,
+ },
+ },
+ },
+ {
+ name: "half1 timeout, half2 timeout",
+ events: []event{
+ {
+ name: "half1",
+ fragment: half1,
+ expectDone: false,
+ memSizeAfterEvent: memSizeOfFrags(half1),
+ },
+ {
+ name: "half1 just before reassembly timeout",
+ clockAdvance: reassemblyTimeout - 1,
+ memSizeAfterEvent: memSizeOfFrags(half1),
+ },
+ {
+ name: "half1 reassembly timeout",
+ clockAdvance: 1,
+ memSizeAfterEvent: 0,
+ },
+ {
+ name: "half2",
+ fragment: half2,
+ expectDone: false,
+ memSizeAfterEvent: memSizeOfFrags(half2),
+ },
+ {
+ name: "half2 just before reassembly timeout",
+ clockAdvance: reassemblyTimeout - 1,
+ memSizeAfterEvent: memSizeOfFrags(half2),
+ },
+ {
+ name: "half2 reassembly timeout",
+ clockAdvance: 1,
+ memSizeAfterEvent: 0,
+ },
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock, nil)
+ for _, event := range test.events {
+ clock.Advance(event.clockAdvance)
+ if frag := event.fragment; frag != nil {
+ _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, pkt(len(frag.data), frag.data))
+ if err != nil {
+ t.Fatalf("%s: f.Process failed: %s", event.name, err)
+ }
+ if done != event.expectDone {
+ t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone)
+ }
+ }
+ if got, want := f.memSize, event.memSizeAfterEvent; got != want {
+ t.Errorf("%s: got f.memSize = %d, want = %d", event.name, got, want)
+ }
+ }
+ })
+ }
+}
+
+func TestMemoryLimits(t *testing.T) {
+ lowLimit := pkt(1, "0").MemSize()
+ highLimit := 3 * lowLimit // Allow at most 3 such packets.
+ f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil)
+ // Send first fragment with id = 0.
+ f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0"))
+ // Send first fragment with id = 1.
+ f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1"))
+ // Send first fragment with id = 2.
+ f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2"))
+
+ // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
+ // evicted.
+ f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3"))
+
+ if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok {
+ t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
+ }
+ if _, ok := f.reassemblers[FragmentID{ID: 1}]; ok {
+ t.Errorf("Memory limits are not respected: id=1 has not been evicted.")
+ }
+ if _, ok := f.reassemblers[FragmentID{ID: 3}]; !ok {
+ t.Errorf("Implementation of memory limits is wrong: id=3 is not present.")
+ }
+}
+
+func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
+ memSize := pkt(1, "0").MemSize()
+ f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil)
+ // Send first fragment with id = 0.
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0"))
+ // Send the same packet again.
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0"))
+
+ if got, want := f.memSize, memSize; got != want {
+ t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
+ }
+}
+
+func TestErrors(t *testing.T) {
+ tests := []struct {
+ name string
+ blockSize uint16
+ first uint16
+ last uint16
+ more bool
+ data string
+ err error
+ }{
+ {
+ name: "exact block size without more",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: false,
+ data: "01",
+ },
+ {
+ name: "exact block size with more",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: true,
+ data: "01",
+ },
+ {
+ name: "exact block size with more and extra data",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: true,
+ data: "012",
+ err: ErrInvalidArgs,
+ },
+ {
+ name: "exact block size with more and too little data",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: true,
+ data: "0",
+ err: ErrInvalidArgs,
+ },
+ {
+ name: "not exact block size with more",
+ blockSize: 2,
+ first: 2,
+ last: 2,
+ more: true,
+ data: "0",
+ err: ErrInvalidArgs,
+ },
+ {
+ name: "not exact block size without more",
+ blockSize: 2,
+ first: 2,
+ last: 2,
+ more: false,
+ data: "0",
+ },
+ {
+ name: "first not a multiple of block size",
+ blockSize: 2,
+ first: 3,
+ last: 4,
+ more: true,
+ data: "01",
+ err: ErrInvalidArgs,
+ },
+ {
+ name: "first more than last",
+ blockSize: 2,
+ first: 4,
+ last: 3,
+ more: true,
+ data: "01",
+ err: ErrInvalidArgs,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, nil)
+ _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, pkt(len(test.data), test.data))
+ if !errors.Is(err, test.err) {
+ t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
+ }
+ if done {
+ t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, true, _), want = (_, _, false, _)", test.first, test.last, test.more, test.data)
+ }
+ })
+ }
+}
+
+type fragmentInfo struct {
+ remaining int
+ copied int
+ offset int
+ more bool
+}
+
+func TestPacketFragmenter(t *testing.T) {
+ const (
+ reserve = 60
+ proto = 0
+ )
+
+ tests := []struct {
+ name string
+ fragmentPayloadLen uint32
+ transportHeaderLen int
+ payloadSize int
+ wantFragments []fragmentInfo
+ }{
+ {
+ name: "Packet exactly fits in MTU",
+ fragmentPayloadLen: 1280,
+ transportHeaderLen: 0,
+ payloadSize: 1280,
+ wantFragments: []fragmentInfo{
+ {remaining: 0, copied: 1280, offset: 0, more: false},
+ },
+ },
+ {
+ name: "Packet exactly does not fit in MTU",
+ fragmentPayloadLen: 1000,
+ transportHeaderLen: 0,
+ payloadSize: 1001,
+ wantFragments: []fragmentInfo{
+ {remaining: 1, copied: 1000, offset: 0, more: true},
+ {remaining: 0, copied: 1, offset: 1000, more: false},
+ },
+ },
+ {
+ name: "Packet has a transport header",
+ fragmentPayloadLen: 560,
+ transportHeaderLen: 40,
+ payloadSize: 560,
+ wantFragments: []fragmentInfo{
+ {remaining: 1, copied: 560, offset: 0, more: true},
+ {remaining: 0, copied: 40, offset: 560, more: false},
+ },
+ },
+ {
+ name: "Packet has a huge transport header",
+ fragmentPayloadLen: 500,
+ transportHeaderLen: 1300,
+ payloadSize: 500,
+ wantFragments: []fragmentInfo{
+ {remaining: 3, copied: 500, offset: 0, more: true},
+ {remaining: 2, copied: 500, offset: 500, more: true},
+ {remaining: 1, copied: 500, offset: 1000, more: true},
+ {remaining: 0, copied: 300, offset: 1500, more: false},
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto)
+ var originalPayload buffer.VectorisedView
+ originalPayload.AppendView(pkt.TransportHeader().View())
+ originalPayload.Append(pkt.Data)
+ var reassembledPayload buffer.VectorisedView
+ pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve)
+ for i := 0; ; i++ {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ wantFragment := test.wantFragments[i]
+ if got := pf.RemainingFragmentCount(); got != wantFragment.remaining {
+ t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining)
+ }
+ if copied != wantFragment.copied {
+ t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied)
+ }
+ if offset != wantFragment.offset {
+ t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset)
+ }
+ if more != wantFragment.more {
+ t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more)
+ }
+ if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen {
+ t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen)
+ }
+ if got := fragPkt.AvailableHeaderBytes(); got != reserve {
+ t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve)
+ }
+ if got := fragPkt.TransportHeader().View().Size(); got != 0 {
+ t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got)
+ }
+ reassembledPayload.Append(fragPkt.Data)
+ if !more {
+ if i != len(test.wantFragments)-1 {
+ t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1)
+ }
+ break
+ }
+ }
+ if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload.ToView()); diff != "" {
+ t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+type testTimeoutHandler struct {
+ pkt *stack.PacketBuffer
+}
+
+func (h *testTimeoutHandler) OnReassemblyTimeout(pkt *stack.PacketBuffer) {
+ h.pkt = pkt
+}
+
+func TestTimeoutHandler(t *testing.T) {
+ const (
+ proto = 99
+ )
+
+ pk1 := pkt(1, "1")
+ pk2 := pkt(1, "2")
+
+ type processParam struct {
+ first uint16
+ last uint16
+ more bool
+ pkt *stack.PacketBuffer
+ }
+
+ tests := []struct {
+ name string
+ params []processParam
+ wantError bool
+ wantPkt *stack.PacketBuffer
+ }{
+ {
+ name: "onTimeout runs",
+ params: []processParam{
+ {
+ first: 0,
+ last: 0,
+ more: true,
+ pkt: pk1,
+ },
+ },
+ wantError: false,
+ wantPkt: pk1,
+ },
+ {
+ name: "no first fragment",
+ params: []processParam{
+ {
+ first: 1,
+ last: 1,
+ more: true,
+ pkt: pk1,
+ },
+ },
+ wantError: false,
+ wantPkt: nil,
+ },
+ {
+ name: "second pkt is ignored",
+ params: []processParam{
+ {
+ first: 0,
+ last: 0,
+ more: true,
+ pkt: pk1,
+ },
+ {
+ first: 0,
+ last: 0,
+ more: true,
+ pkt: pk2,
+ },
+ },
+ wantError: false,
+ wantPkt: pk1,
+ },
+ {
+ name: "invalid args - first is greater than last",
+ params: []processParam{
+ {
+ first: 1,
+ last: 0,
+ more: true,
+ pkt: pk1,
+ },
+ },
+ wantError: true,
+ wantPkt: nil,
+ },
+ }
+
+ id := FragmentID{ID: 0}
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ handler := &testTimeoutHandler{pkt: nil}
+
+ f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, handler)
+
+ for _, p := range test.params {
+ if _, _, _, err := f.Process(id, p.first, p.last, p.more, proto, p.pkt); err != nil && !test.wantError {
+ t.Errorf("f.Process error = %s", err)
+ }
+ }
+ if !test.wantError {
+ r, ok := f.reassemblers[id]
+ if !ok {
+ t.Fatal("Reassembler not found")
+ }
+ f.release(r, true)
+ }
+ switch {
+ case handler.pkt != nil && test.wantPkt == nil:
+ t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data.ToView())
+ case handler.pkt == nil && test.wantPkt != nil:
+ t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data.ToView())
+ case handler.pkt != nil && test.wantPkt != nil:
+ if diff := cmp.Diff(test.wantPkt.Data.ToView(), handler.pkt.Data.ToView()); diff != "" {
+ t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go
new file mode 100644
index 000000000..933d63d32
--- /dev/null
+++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go
@@ -0,0 +1,182 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fragmentation
+
+import (
+ "math"
+ "sort"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type hole struct {
+ first uint16
+ last uint16
+ filled bool
+ final bool
+ // pkt is the fragment packet if hole is filled. We keep the whole pkt rather
+ // than the fragmented payload to prevent binding to specific buffer types.
+ pkt *stack.PacketBuffer
+}
+
+type reassembler struct {
+ reassemblerEntry
+ id FragmentID
+ memSize int
+ proto uint8
+ mu sync.Mutex
+ holes []hole
+ filled int
+ done bool
+ creationTime int64
+ pkt *stack.PacketBuffer
+}
+
+func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
+ r := &reassembler{
+ id: id,
+ creationTime: clock.NowMonotonic(),
+ }
+ r.holes = append(r.holes, hole{
+ first: 0,
+ last: math.MaxUint16,
+ filled: false,
+ final: true,
+ })
+ return r
+}
+
+func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (*stack.PacketBuffer, uint8, bool, int, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.done {
+ // A concurrent goroutine might have already reassembled
+ // the packet and emptied the heap while this goroutine
+ // was waiting on the mutex. We don't have to do anything in this case.
+ return nil, 0, false, 0, nil
+ }
+
+ var holeFound bool
+ var memConsumed int
+ for i := range r.holes {
+ currentHole := &r.holes[i]
+
+ if last < currentHole.first || currentHole.last < first {
+ continue
+ }
+ // For IPv6, overlaps with an existing fragment are explicitly forbidden by
+ // RFC 8200 section 4.5:
+ // If any of the fragments being reassembled overlap with any other
+ // fragments being reassembled for the same packet, reassembly of that
+ // packet must be abandoned and all the fragments that have been received
+ // for that packet must be discarded, and no ICMP error messages should be
+ // sent.
+ //
+ // It is not explicitly forbidden for IPv4, but to keep parity with Linux we
+ // disallow it as well:
+ // https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349
+ if first < currentHole.first || currentHole.last < last {
+ // Incoming fragment only partially fits in the free hole.
+ return nil, 0, false, 0, ErrFragmentOverlap
+ }
+ if !more {
+ if !currentHole.final || currentHole.filled && currentHole.last != last {
+ // We have another final fragment, which does not perfectly overlap.
+ return nil, 0, false, 0, ErrFragmentConflict
+ }
+ }
+
+ holeFound = true
+ if currentHole.filled {
+ // Incoming fragment is a duplicate.
+ continue
+ }
+
+ // We are populating the current hole with the payload and creating a new
+ // hole for any unfilled ranges on either end.
+ if first > currentHole.first {
+ r.holes = append(r.holes, hole{
+ first: currentHole.first,
+ last: first - 1,
+ filled: false,
+ final: false,
+ })
+ }
+ if last < currentHole.last && more {
+ r.holes = append(r.holes, hole{
+ first: last + 1,
+ last: currentHole.last,
+ filled: false,
+ final: currentHole.final,
+ })
+ currentHole.final = false
+ }
+ memConsumed = pkt.MemSize()
+ r.memSize += memConsumed
+ // Update the current hole to precisely match the incoming fragment.
+ r.holes[i] = hole{
+ first: first,
+ last: last,
+ filled: true,
+ final: currentHole.final,
+ pkt: pkt,
+ }
+ r.filled++
+ // For IPv6, it is possible to have different Protocol values between
+ // fragments of a packet (because, unlike IPv4, the Protocol is not used to
+ // identify a fragment). In this case, only the Protocol of the first
+ // fragment must be used as per RFC 8200 Section 4.5.
+ //
+ // TODO(gvisor.dev/issue/3648): During reassembly of an IPv6 packet, IP
+ // options received in the first fragment should be used - and they should
+ // override options from following fragments.
+ if first == 0 {
+ r.pkt = pkt
+ r.proto = proto
+ }
+
+ break
+ }
+ if !holeFound {
+ // Incoming fragment is beyond end.
+ return nil, 0, false, 0, ErrFragmentConflict
+ }
+
+ // Check if all the holes have been filled and we are ready to reassemble.
+ if r.filled < len(r.holes) {
+ return nil, 0, false, memConsumed, nil
+ }
+
+ sort.Slice(r.holes, func(i, j int) bool {
+ return r.holes[i].first < r.holes[j].first
+ })
+
+ resPkt := r.holes[0].pkt
+ for i := 1; i < len(r.holes); i++ {
+ fragPkt := r.holes[i].pkt
+ fragPkt.Data.ReadToVV(&resPkt.Data, fragPkt.Data.Size())
+ }
+ return resPkt, r.proto, true, memConsumed, nil
+}
+
+func (r *reassembler) checkDoneOrMark() bool {
+ r.mu.Lock()
+ prev := r.done
+ r.done = true
+ r.mu.Unlock()
+ return prev
+}
diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go
new file mode 100644
index 000000000..214a93709
--- /dev/null
+++ b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go
@@ -0,0 +1,233 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fragmentation
+
+import (
+ "bytes"
+ "math"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type processParams struct {
+ first uint16
+ last uint16
+ more bool
+ pkt *stack.PacketBuffer
+ wantDone bool
+ wantError error
+}
+
+func TestReassemblerProcess(t *testing.T) {
+ const proto = 99
+
+ v := func(size int) buffer.View {
+ payload := buffer.NewView(size)
+ for i := 1; i < size; i++ {
+ payload[i] = uint8(i) * 3
+ }
+ return payload
+ }
+
+ pkt := func(sizes ...int) *stack.PacketBuffer {
+ var vv buffer.VectorisedView
+ for _, size := range sizes {
+ vv.AppendView(v(size))
+ }
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ })
+ }
+
+ var tests = []struct {
+ name string
+ params []processParams
+ want []hole
+ wantPkt *stack.PacketBuffer
+ }{
+ {
+ name: "No fragments",
+ params: nil,
+ want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}},
+ },
+ {
+ name: "One fragment at beginning",
+ params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: false, pkt: pkt(2)},
+ {first: 2, last: math.MaxUint16, filled: false, final: true},
+ },
+ },
+ {
+ name: "One fragment in the middle",
+ params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
+ want: []hole{
+ {first: 1, last: 2, filled: true, final: false, pkt: pkt(2)},
+ {first: 0, last: 0, filled: false, final: false},
+ {first: 3, last: math.MaxUint16, filled: false, final: true},
+ },
+ },
+ {
+ name: "One fragment at the end",
+ params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}},
+ want: []hole{
+ {first: 1, last: 2, filled: true, final: true, pkt: pkt(2)},
+ {first: 0, last: 0, filled: false},
+ },
+ },
+ {
+ name: "One fragment completing a packet",
+ params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}},
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: true},
+ },
+ wantPkt: pkt(2),
+ },
+ {
+ name: "Two fragments completing a packet",
+ params: []processParams{
+ {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: false},
+ {first: 2, last: 3, filled: true, final: true},
+ },
+ wantPkt: pkt(2, 2),
+ },
+ {
+ name: "Two fragments completing a packet with a duplicate",
+ params: []processParams{
+ {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: false},
+ {first: 2, last: 3, filled: true, final: true},
+ },
+ wantPkt: pkt(2, 2),
+ },
+ {
+ name: "Two fragments completing a packet with a partial duplicate",
+ params: []processParams{
+ {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil},
+ {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 3, filled: true, final: false},
+ {first: 4, last: 5, filled: true, final: true},
+ },
+ wantPkt: pkt(4, 2),
+ },
+ {
+ name: "Two overlapping fragments",
+ params: []processParams{
+ {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil},
+ {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap},
+ },
+ want: []hole{
+ {first: 0, last: 10, filled: true, final: false, pkt: pkt(11)},
+ {first: 11, last: math.MaxUint16, filled: false, final: true},
+ },
+ },
+ {
+ name: "Two final fragments with different ends",
+ params: []processParams{
+ {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
+ {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict},
+ },
+ want: []hole{
+ {first: 10, last: 14, filled: true, final: true, pkt: pkt(5)},
+ {first: 0, last: 9, filled: false, final: false},
+ },
+ },
+ {
+ name: "Two final fragments - duplicate",
+ params: []processParams{
+ {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil},
+ {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
+ },
+ want: []hole{
+ {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)},
+ {first: 0, last: 4, filled: false, final: false},
+ },
+ },
+ {
+ name: "Two final fragments - duplicate, with different ends",
+ params: []processParams{
+ {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil},
+ {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict},
+ },
+ want: []hole{
+ {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)},
+ {first: 0, last: 4, filled: false, final: false},
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ r := newReassembler(FragmentID{}, &faketime.NullClock{})
+ var resPkt *stack.PacketBuffer
+ var isDone bool
+ for _, param := range test.params {
+ pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt)
+ if done != param.wantDone || err != param.wantError {
+ t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError)
+ }
+ if done {
+ resPkt = pkt
+ isDone = true
+ }
+ }
+
+ ignorePkt := func(a, b *stack.PacketBuffer) bool { return true }
+ cmpPktData := func(a, b *stack.PacketBuffer) bool {
+ if a == nil || b == nil {
+ return a == b
+ }
+ return bytes.Equal(a.Data.ToOwnedView(), b.Data.ToOwnedView())
+ }
+
+ if isDone {
+ if diff := cmp.Diff(
+ test.want, r.holes,
+ cmp.AllowUnexported(hole{}),
+ // Do not compare pkt in hole. Data will be altered.
+ cmp.Comparer(ignorePkt),
+ ); diff != "" {
+ t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(test.wantPkt, resPkt, cmp.Comparer(cmpPktData)); diff != "" {
+ t.Errorf("Reassembled pkt mismatch (-want +got):\n%s", diff)
+ }
+ } else {
+ if diff := cmp.Diff(
+ test.want, r.holes,
+ cmp.AllowUnexported(hole{}),
+ cmp.Comparer(cmpPktData),
+ ); diff != "" {
+ t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/internal/ip/BUILD b/pkg/tcpip/network/internal/ip/BUILD
index 0f55a9770..d21b4c7ef 100644
--- a/pkg/tcpip/network/internal/ip/BUILD
+++ b/pkg/tcpip/network/internal/ip/BUILD
@@ -4,8 +4,16 @@ package(licenses = ["notice"])
go_library(
name = "ip",
- srcs = ["duplicate_address_detection.go"],
- visibility = ["//visibility:public"],
+ srcs = [
+ "duplicate_address_detection.go",
+ "generic_multicast_protocol.go",
+ "stats.go",
+ ],
+ visibility = [
+ "//pkg/tcpip/network/arp:__pkg__",
+ "//pkg/tcpip/network/ipv4:__pkg__",
+ "//pkg/tcpip/network/ipv6:__pkg__",
+ ],
deps = [
"//pkg/sync",
"//pkg/tcpip",
@@ -16,7 +24,10 @@ go_library(
go_test(
name = "ip_x_test",
size = "small",
- srcs = ["duplicate_address_detection_test.go"],
+ srcs = [
+ "duplicate_address_detection_test.go",
+ "generic_multicast_protocol_test.go",
+ ],
deps = [
":ip",
"//pkg/sync",
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
new file mode 100644
index 000000000..b9f129728
--- /dev/null
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
@@ -0,0 +1,696 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ip holds IPv4/IPv6 common utilities.
+package ip
+
+import (
+ "fmt"
+ "math/rand"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// hostState is the state a host may be in for a multicast group.
+type hostState int
+
+// The states below are generic across IGMPv2 (RFC 2236 section 6) and MLDv1
+// (RFC 2710 section 5). Even though the states are generic across both IGMPv2
+// and MLDv1, IGMPv2 terminology will be used.
+//
+// ______________receive query______________
+// | |
+// | _____send or receive report_____ |
+// | | | |
+// V | V |
+// +-------+ +-----------+ +------------+ +-------------------+ +--------+ |
+// | Non-M | | Pending-M | | Delaying-M | | Queued Delaying-M | | Idle-M | -
+// +-------+ +-----------+ +------------+ +-------------------+ +--------+
+// | ^ | ^ | ^ | ^
+// | | | | | | | |
+// ---------- ------- ---------- -------------
+// initialize new send inital fail to send send or receive
+// group membership report delayed report report
+//
+// Not shown in the diagram above, but any state may transition into the non
+// member state when a group is left.
+const (
+ // nonMember is the "'Non-Member' state, when the host does not belong to the
+ // group on the interface. This is the initial state for all memberships on
+ // all network interfaces; it requires no storage in the host."
+ //
+ // 'Non-Listener' is the MLDv1 term used to describe this state.
+ //
+ // This state is used to keep track of groups that have been joined locally,
+ // but without advertising the membership to the network.
+ nonMember hostState = iota
+
+ // pendingMember is a newly joined member that is waiting to successfully send
+ // the initial set of reports.
+ //
+ // This is not an RFC defined state; it is an implementation specific state to
+ // track that the initial report needs to be sent.
+ //
+ // MAY NOT transition to the idle member state from this state.
+ pendingMember
+
+ // delayingMember is the "'Delaying Member' state, when the host belongs to
+ // the group on the interface and has a report delay timer running for that
+ // membership."
+ //
+ // 'Delaying Listener' is the MLDv1 term used to describe this state.
+ delayingMember
+
+ // queuedDelayingMember is a delayingMember that failed to send a report after
+ // its delayed report timer fired. Hosts in this state are waiting to attempt
+ // retransmission of the delayed report.
+ //
+ // This is not an RFC defined state; it is an implementation specific state to
+ // track that the delayed report needs to be sent.
+ //
+ // May transition to idle member if a report is received for a group.
+ queuedDelayingMember
+
+ // idleMember is the "Idle Member" state, when the host belongs to the group
+ // on the interface and does not have a report delay timer running for that
+ // membership.
+ //
+ // 'Idle Listener' is the MLDv1 term used to describe this state.
+ idleMember
+)
+
+func (s hostState) isDelayingMember() bool {
+ switch s {
+ case nonMember, pendingMember, idleMember:
+ return false
+ case delayingMember, queuedDelayingMember:
+ return true
+ default:
+ panic(fmt.Sprintf("unrecognized host state = %d", s))
+ }
+}
+
+// multicastGroupState holds the Generic Multicast Protocol state for a
+// multicast group.
+type multicastGroupState struct {
+ // joins is the number of times the group has been joined.
+ joins uint64
+
+ // state holds the host's state for the group.
+ state hostState
+
+ // lastToSendReport is true if we sent the last report for the group. It is
+ // used to track whether there are other hosts on the subnet that are also
+ // members of the group.
+ //
+ // Defined in RFC 2236 section 6 page 9 for IGMPv2 and RFC 2710 section 5 page
+ // 8 for MLDv1.
+ lastToSendReport bool
+
+ // delayedReportJob is used to delay sending responses to membership report
+ // messages in order to reduce duplicate reports from multiple hosts on the
+ // interface.
+ //
+ // Must not be nil.
+ delayedReportJob *tcpip.Job
+
+ // delyedReportJobFiresAt is the time when the delayed report job will fire.
+ //
+ // A zero value indicates that the job is not scheduled.
+ delayedReportJobFiresAt time.Time
+}
+
+func (m *multicastGroupState) cancelDelayedReportJob() {
+ m.delayedReportJob.Cancel()
+ m.delayedReportJobFiresAt = time.Time{}
+}
+
+// GenericMulticastProtocolOptions holds options for the generic multicast
+// protocol.
+type GenericMulticastProtocolOptions struct {
+ // Rand is the source of random numbers.
+ Rand *rand.Rand
+
+ // Clock is the clock used to create timers.
+ Clock tcpip.Clock
+
+ // Protocol is the implementation of the variant of multicast group protocol
+ // in use.
+ Protocol MulticastGroupProtocol
+
+ // MaxUnsolicitedReportDelay is the maximum amount of time to wait between
+ // transmitting unsolicited reports.
+ //
+ // Unsolicited reports are transmitted when a group is newly joined.
+ MaxUnsolicitedReportDelay time.Duration
+
+ // AllNodesAddress is a multicast address that all nodes on a network should
+ // be a member of.
+ //
+ // This address will not have the generic multicast protocol performed on it;
+ // it will be left in the non member/listener state, and packets will never
+ // be sent for it.
+ AllNodesAddress tcpip.Address
+}
+
+// MulticastGroupProtocol is a multicast group protocol whose core state machine
+// can be represented by GenericMulticastProtocolState.
+type MulticastGroupProtocol interface {
+ // Enabled indicates whether the generic multicast protocol will be
+ // performed.
+ //
+ // When enabled, the protocol may transmit report and leave messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // packets.
+ //
+ // When disabled, the protocol will still keep track of locally joined groups,
+ // it just won't transmit and handle packets, or update groups' state.
+ Enabled() bool
+
+ // SendReport sends a multicast report for the specified group address.
+ //
+ // Returns false if the caller should queue the report to be sent later. Note,
+ // returning false does not mean that the receiver hit an error.
+ SendReport(groupAddress tcpip.Address) (sent bool, err tcpip.Error)
+
+ // SendLeave sends a multicast leave for the specified group address.
+ SendLeave(groupAddress tcpip.Address) tcpip.Error
+}
+
+// GenericMulticastProtocolState is the per interface generic multicast protocol
+// state.
+//
+// There is actually no protocol named "Generic Multicast Protocol". Instead,
+// the term used to refer to a generic multicast protocol that applies to both
+// IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state
+// machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710.
+//
+// Callers must synchronize accesses to the generic multicast protocol state;
+// GenericMulticastProtocolState obtains no locks in any of its methods. The
+// only exception to this is GenericMulticastProtocolState's timer/job callbacks
+// which will obtain the lock provided to the GenericMulticastProtocolState when
+// it is initialized.
+//
+// GenericMulticastProtocolState.Init MUST be called before calling any of
+// the methods on GenericMulticastProtocolState.
+//
+// GenericMulticastProtocolState.MakeAllNonMemberLocked MUST be called when the
+// multicast group protocol is disabled so that leave messages may be sent.
+type GenericMulticastProtocolState struct {
+ // Do not allow overwriting this state.
+ _ sync.NoCopy
+
+ opts GenericMulticastProtocolOptions
+
+ // memberships holds group addresses and their associated state.
+ memberships map[tcpip.Address]multicastGroupState
+
+ // protocolMU is the mutex used to protect the protocol.
+ protocolMU *sync.RWMutex
+}
+
+// Init initializes the Generic Multicast Protocol state.
+//
+// Must only be called once for the lifetime of g; Init will panic if it is
+// called twice.
+//
+// The GenericMulticastProtocolState will only grab the lock when timers/jobs
+// fire.
+//
+// Note: the methods on opts.Protocol will always be called while protocolMU is
+// held.
+func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) {
+ if g.memberships != nil {
+ panic("attempted to initialize generic membership protocol state twice")
+ }
+
+ *g = GenericMulticastProtocolState{
+ opts: opts,
+ memberships: make(map[tcpip.Address]multicastGroupState),
+ protocolMU: protocolMU,
+ }
+}
+
+// MakeAllNonMemberLocked transitions all groups to the non-member state.
+//
+// The groups will still be considered joined locally.
+//
+// MUST be called when the multicast group protocol is disabled.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
+
+ for groupAddress, info := range g.memberships {
+ g.transitionToNonMemberLocked(groupAddress, &info)
+ g.memberships[groupAddress] = info
+ }
+}
+
+// InitializeGroupsLocked initializes each group, as if they were newly joined
+// but without affecting the groups' join count.
+//
+// Must only be called after calling MakeAllNonMember as a group should not be
+// initialized while it is not in the non-member state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) InitializeGroupsLocked() {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
+
+ for groupAddress, info := range g.memberships {
+ g.initializeNewMemberLocked(groupAddress, &info)
+ g.memberships[groupAddress] = info
+ }
+}
+
+// SendQueuedReportsLocked attempts to send reports for groups that failed to
+// send reports during their last attempt.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() {
+ for groupAddress, info := range g.memberships {
+ switch info.state {
+ case nonMember, delayingMember, idleMember:
+ case pendingMember:
+ // pendingMembers failed to send their initial unsolicited report so try
+ // to send the report and queue the extra unsolicited reports.
+ g.maybeSendInitialReportLocked(groupAddress, &info)
+ case queuedDelayingMember:
+ // queuedDelayingMembers failed to send their delayed reports so try to
+ // send the report and transition them to the idle state.
+ g.maybeSendDelayedReportLocked(groupAddress, &info)
+ default:
+ panic(fmt.Sprintf("unrecognized host state = %d", info.state))
+ }
+ g.memberships[groupAddress] = info
+ }
+}
+
+// JoinGroupLocked handles joining a new group.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address) {
+ if info, ok := g.memberships[groupAddress]; ok {
+ // The group has already been joined.
+ info.joins++
+ g.memberships[groupAddress] = info
+ return
+ }
+
+ info := multicastGroupState{
+ // Since we just joined the group, its count is 1.
+ joins: 1,
+ // The state will be updated below, if required.
+ state: nonMember,
+ lastToSendReport: false,
+ delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() {
+ if !g.opts.Protocol.Enabled() {
+ panic(fmt.Sprintf("delayed report job fired for group %s while the multicast group protocol is disabled", groupAddress))
+ }
+
+ info, ok := g.memberships[groupAddress]
+ if !ok {
+ panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress))
+ }
+
+ g.maybeSendDelayedReportLocked(groupAddress, &info)
+ g.memberships[groupAddress] = info
+ }),
+ }
+
+ if g.opts.Protocol.Enabled() {
+ g.initializeNewMemberLocked(groupAddress, &info)
+ }
+
+ g.memberships[groupAddress] = info
+}
+
+// IsLocallyJoinedRLocked returns true if the group is locally joined.
+//
+// Precondition: g.protocolMU must be read locked.
+func (g *GenericMulticastProtocolState) IsLocallyJoinedRLocked(groupAddress tcpip.Address) bool {
+ _, ok := g.memberships[groupAddress]
+ return ok
+}
+
+// LeaveGroupLocked handles leaving the group.
+//
+// Returns false if the group is not currently joined.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Address) bool {
+ info, ok := g.memberships[groupAddress]
+ if !ok {
+ return false
+ }
+
+ if info.joins == 0 {
+ panic(fmt.Sprintf("tried to leave group %s with a join count of 0", groupAddress))
+ }
+ info.joins--
+ if info.joins != 0 {
+ // If we still have outstanding joins, then do nothing further.
+ g.memberships[groupAddress] = info
+ return true
+ }
+
+ g.transitionToNonMemberLocked(groupAddress, &info)
+ delete(g.memberships, groupAddress)
+ return true
+}
+
+// HandleQueryLocked handles a query message with the specified maximum response
+// time.
+//
+// If the group address is unspecified, then reports will be scheduled for all
+// joined groups.
+//
+// Report(s) will be scheduled to be sent after a random duration between 0 and
+// the maximum response time.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
+
+ // As per RFC 2236 section 2.4 (for IGMPv2),
+ //
+ // In a Membership Query message, the group address field is set to zero
+ // when sending a General Query, and set to the group address being
+ // queried when sending a Group-Specific Query.
+ //
+ // As per RFC 2710 section 3.6 (for MLDv1),
+ //
+ // In a Query message, the Multicast Address field is set to zero when
+ // sending a General Query, and set to a specific IPv6 multicast address
+ // when sending a Multicast-Address-Specific Query.
+ if groupAddress.Unspecified() {
+ // This is a general query as the group address is unspecified.
+ for groupAddress, info := range g.memberships {
+ g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime)
+ g.memberships[groupAddress] = info
+ }
+ } else if info, ok := g.memberships[groupAddress]; ok {
+ g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime)
+ g.memberships[groupAddress] = info
+ }
+}
+
+// HandleReportLocked handles a report message.
+//
+// If the report is for a joined group, any active delayed report will be
+// cancelled and the host state for the group transitions to idle.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
+
+ // As per RFC 2236 section 3 pages 3-4 (for IGMPv2),
+ //
+ // If the host receives another host's Report (version 1 or 2) while it has
+ // a timer running, it stops its timer for the specified group and does not
+ // send a Report
+ //
+ // As per RFC 2710 section 4 page 6 (for MLDv1),
+ //
+ // If a node receives another node's Report from an interface for a
+ // multicast address while it has a timer running for that same address
+ // on that interface, it stops its timer and does not send a Report for
+ // that address, thus suppressing duplicate reports on the link.
+ if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() {
+ info.cancelDelayedReportJob()
+ info.lastToSendReport = false
+ info.state = idleMember
+ g.memberships[groupAddress] = info
+ }
+}
+
+// initializeNewMemberLocked initializes a new group membership.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state != nonMember {
+ panic(fmt.Sprintf("host must be in non-member state to be initialized; group = %s, state = %d", groupAddress, info.state))
+ }
+
+ info.lastToSendReport = false
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ info.state = idleMember
+ return
+ }
+
+ info.state = pendingMember
+ g.maybeSendInitialReportLocked(groupAddress, info)
+}
+
+// maybeSendInitialReportLocked attempts to start transmission of the initial
+// set of reports after newly joining a group.
+//
+// Host must be in pending member state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) maybeSendInitialReportLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state != pendingMember {
+ panic(fmt.Sprintf("host must be in pending member state to send initial reports; group = %s, state = %d", groupAddress, info.state))
+ }
+
+ // As per RFC 2236 section 3 page 5 (for IGMPv2),
+ //
+ // When a host joins a multicast group, it should immediately transmit an
+ // unsolicited Version 2 Membership Report for that group" ... "it is
+ // recommended that it be repeated".
+ //
+ // As per RFC 2710 section 4 page 6 (for MLDv1),
+ //
+ // When a node starts listening to a multicast address on an interface,
+ // it should immediately transmit an unsolicited Report for that address
+ // on that interface, in case it is the first listener on the link. To
+ // cover the possibility of the initial Report being lost or damaged, it
+ // is recommended that it be repeated once or twice after short delays
+ // [Unsolicited Report Interval].
+ //
+ // TODO(gvisor.dev/issue/4901): Support a configurable number of initial
+ // unsolicited reports.
+ sent, err := g.opts.Protocol.SendReport(groupAddress)
+ if err == nil && sent {
+ info.lastToSendReport = true
+ g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay)
+ }
+}
+
+// maybeSendDelayedReportLocked attempts to send the delayed report.
+//
+// Host must be in pending, delaying or queued delaying member state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if !info.state.isDelayingMember() {
+ panic(fmt.Sprintf("host must be in delaying or queued delaying member state to send delayed reports; group = %s, state = %d", groupAddress, info.state))
+ }
+
+ sent, err := g.opts.Protocol.SendReport(groupAddress)
+ if err == nil && sent {
+ info.lastToSendReport = true
+ info.state = idleMember
+ } else {
+ info.state = queuedDelayingMember
+ }
+}
+
+// maybeSendLeave attempts to send a leave message.
+func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) {
+ if !g.opts.Protocol.Enabled() || !lastToSendReport {
+ return
+ }
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ return
+ }
+
+ // Okay to ignore the error here as if packet write failed, the multicast
+ // routers will eventually drop our membership anyways. If the interface is
+ // being disabled or removed, the generic multicast protocol's should be
+ // cleared eventually.
+ //
+ // As per RFC 2236 section 3 page 5 (for IGMPv2),
+ //
+ // When a router receives a Report, it adds the group being reported to
+ // the list of multicast group memberships on the network on which it
+ // received the Report and sets the timer for the membership to the
+ // [Group Membership Interval]. Repeated Reports refresh the timer. If
+ // no Reports are received for a particular group before this timer has
+ // expired, the router assumes that the group has no local members and
+ // that it need not forward remotely-originated multicasts for that
+ // group onto the attached network.
+ //
+ // As per RFC 2710 section 4 page 5 (for MLDv1),
+ //
+ // When a router receives a Report from a link, if the reported address
+ // is not already present in the router's list of multicast address
+ // having listeners on that link, the reported address is added to the
+ // list, its timer is set to [Multicast Listener Interval], and its
+ // appearance is made known to the router's multicast routing component.
+ // If a Report is received for a multicast address that is already
+ // present in the router's list, the timer for that address is reset to
+ // [Multicast Listener Interval]. If an address's timer expires, it is
+ // assumed that there are no longer any listeners for that address
+ // present on the link, so it is deleted from the list and its
+ // disappearance is made known to the multicast routing component.
+ //
+ // The requirement to send a leave message is also optional (it MAY be
+ // skipped):
+ //
+ // As per RFC 2236 section 6 page 8 (for IGMPv2),
+ //
+ // "send leave" for the group on the interface. If the interface
+ // state says the Querier is running IGMPv1, this action SHOULD be
+ // skipped. If the flag saying we were the last host to report is
+ // cleared, this action MAY be skipped. The Leave Message is sent to
+ // the ALL-ROUTERS group (224.0.0.2).
+ //
+ // As per RFC 2710 section 5 page 8 (for MLDv1),
+ //
+ // "send done" for the address on the interface. If the flag saying
+ // we were the last node to report is cleared, this action MAY be
+ // skipped. The Done message is sent to the link-scope all-routers
+ // address (FF02::2).
+ _ = g.opts.Protocol.SendLeave(groupAddress)
+}
+
+// transitionToNonMemberLocked transitions the given multicast group the the
+// non-member/listener state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state == nonMember {
+ return
+ }
+
+ info.cancelDelayedReportJob()
+ g.maybeSendLeave(groupAddress, info.lastToSendReport)
+ info.lastToSendReport = false
+ info.state = nonMember
+}
+
+// setDelayTimerForAddressRLocked sets timer to send a delay report.
+//
+// Precondition: g.protocolMU MUST be read locked.
+func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) {
+ if info.state == nonMember {
+ return
+ }
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ return
+ }
+
+ // As per RFC 2236 section 3 page 3 (for IGMPv2),
+ //
+ // If a timer for the group is already unning, it is reset to the random
+ // value only if the requested Max Response Time is less than the remaining
+ // value of the running timer.
+ //
+ // As per RFC 2710 section 4 page 5 (for MLDv1),
+ //
+ // If a timer for any address is already running, it is reset to the new
+ // random value only if the requested Maximum Response Delay is less than
+ // the remaining value of the running timer.
+ now := time.Unix(0 /* seconds */, g.opts.Clock.NowNanoseconds())
+ if info.state == delayingMember {
+ if info.delayedReportJobFiresAt.IsZero() {
+ panic(fmt.Sprintf("delayed report unscheduled while in the delaying member state; group = %s", groupAddress))
+ }
+
+ if info.delayedReportJobFiresAt.Sub(now) <= maxResponseTime {
+ // The timer is scheduled to fire before the maximum response time so we
+ // leave our timer as is.
+ return
+ }
+ }
+
+ info.state = delayingMember
+ info.cancelDelayedReportJob()
+ maxResponseTime = g.calculateDelayTimerDuration(maxResponseTime)
+ info.delayedReportJob.Schedule(maxResponseTime)
+ info.delayedReportJobFiresAt = now.Add(maxResponseTime)
+}
+
+// calculateDelayTimerDuration returns a random time between (0, maxRespTime].
+func (g *GenericMulticastProtocolState) calculateDelayTimerDuration(maxRespTime time.Duration) time.Duration {
+ // As per RFC 2236 section 3 page 3 (for IGMPv2),
+ //
+ // When a host receives a Group-Specific Query, it sets a delay timer to a
+ // random value selected from the range (0, Max Response Time]...
+ //
+ // As per RFC 2710 section 4 page 6 (for MLDv1),
+ //
+ // When a node receives a Multicast-Address-Specific Query, if it is
+ // listening to the queried Multicast Address on the interface from
+ // which the Query was received, it sets a delay timer for that address
+ // to a random value selected from the range [0, Maximum Response Delay],
+ // as above.
+ if maxRespTime == 0 {
+ return 0
+ }
+ return time.Duration(g.opts.Rand.Int63n(int64(maxRespTime)))
+}
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
new file mode 100644
index 000000000..381460c82
--- /dev/null
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
@@ -0,0 +1,805 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip_test
+
+import (
+ "math/rand"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
+)
+
+const maxUnsolicitedReportDelay = time.Second
+
+var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil)
+
+type mockMulticastGroupProtocolProtectedFields struct {
+ sync.RWMutex
+
+ genericMulticastGroup ip.GenericMulticastProtocolState
+ sendReportGroupAddrCount map[tcpip.Address]int
+ sendLeaveGroupAddrCount map[tcpip.Address]int
+ makeQueuePackets bool
+ disabled bool
+}
+
+type mockMulticastGroupProtocol struct {
+ t *testing.T
+
+ mu mockMulticastGroupProtocolProtectedFields
+}
+
+func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.initLocked()
+ opts.Protocol = m
+ m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts)
+}
+
+func (m *mockMulticastGroupProtocol) initLocked() {
+ m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int)
+ m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
+}
+
+func (m *mockMulticastGroupProtocol) setEnabled(v bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.disabled = !v
+}
+
+func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.makeQueuePackets = v
+}
+
+func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.JoinGroupLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return m.mu.genericMulticastGroup.LeaveGroupLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.HandleReportLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime)
+}
+
+func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) makeAllNonMember() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.MakeAllNonMemberLocked()
+}
+
+func (m *mockMulticastGroupProtocol) initializeGroups() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.InitializeGroupsLocked()
+}
+
+func (m *mockMulticastGroupProtocol) sendQueuedReports() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.SendQueuedReportsLocked()
+}
+
+// Enabled implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be read locked.
+func (m *mockMulticastGroupProtocol) Enabled() bool {
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled")
+ }
+
+ return !m.mu.disabled
+}
+
+// SendReport implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be locked.
+func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
+ }
+ if m.mu.TryRLock() {
+ m.mu.RUnlock()
+ m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
+ }
+
+ m.mu.sendReportGroupAddrCount[groupAddress]++
+ return !m.mu.makeQueuePackets, nil
+}
+
+// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be locked.
+func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error {
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
+ }
+ if m.mu.TryRLock() {
+ m.mu.RUnlock()
+ m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
+ }
+
+ m.mu.sendLeaveGroupAddrCount[groupAddress]++
+ return nil
+}
+
+func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ sendReportGroupAddrCount := make(map[tcpip.Address]int)
+ for _, a := range sendReportGroupAddresses {
+ sendReportGroupAddrCount[a] = 1
+ }
+
+ sendLeaveGroupAddrCount := make(map[tcpip.Address]int)
+ for _, a := range sendLeaveGroupAddresses {
+ sendLeaveGroupAddrCount[a] = 1
+ }
+
+ diff := cmp.Diff(
+ &mockMulticastGroupProtocol{
+ mu: mockMulticastGroupProtocolProtectedFields{
+ sendReportGroupAddrCount: sendReportGroupAddrCount,
+ sendLeaveGroupAddrCount: sendLeaveGroupAddrCount,
+ },
+ },
+ m,
+ cmp.AllowUnexported(mockMulticastGroupProtocol{}),
+ cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}),
+ // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t
+ cmp.FilterPath(
+ func(p cmp.Path) bool {
+ switch p.Last().String() {
+ case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup":
+ return true
+ }
+ return false
+ },
+ cmp.Ignore(),
+ ),
+ )
+ m.initLocked()
+ return diff
+}
+
+func TestJoinGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ shouldSendReports bool
+ }{
+ {
+ name: "Normal group",
+ addr: addr1,
+ shouldSendReports: true,
+ },
+ {
+ name: "All-nodes group",
+ addr: addr2,
+ shouldSendReports: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(0)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr2,
+ })
+
+ // Joining a group should send a report immediately and another after
+ // a random interval between 0 and the maximum unsolicited report delay.
+ mgp.joinGroup(test.addr)
+ if test.shouldSendReports {
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestLeaveGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ shouldSendMessages bool
+ }{
+ {
+ name: "Normal group",
+ addr: addr1,
+ shouldSendMessages: true,
+ },
+ {
+ name: "All-nodes group",
+ addr: addr2,
+ shouldSendMessages: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(1)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr2,
+ })
+
+ mgp.joinGroup(test.addr)
+ if test.shouldSendMessages {
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Leaving a group should send a leave report immediately and cancel any
+ // delayed reports.
+ {
+
+ if !mgp.leaveGroup(test.addr) {
+ t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr)
+ }
+ }
+ if test.shouldSendMessages {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ //
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestHandleReport(t *testing.T) {
+ tests := []struct {
+ name string
+ reportAddr tcpip.Address
+ expectReportsFor []tcpip.Address
+ }{
+ {
+ name: "Unpecified empty",
+ reportAddr: "",
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Unpecified any",
+ reportAddr: "\x00",
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Specified",
+ reportAddr: addr1,
+ expectReportsFor: []tcpip.Address{addr2},
+ },
+ {
+ name: "Specified all-nodes",
+ reportAddr: addr3,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Specified other",
+ reportAddr: addr4,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(2)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ mgp.joinGroup(addr1)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr2)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr3)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a report for a group we have a timer scheduled for should
+ // cancel our delayed report timer for the group.
+ mgp.handleReport(test.reportAddr)
+ if len(test.expectReportsFor) != 0 {
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestHandleQuery(t *testing.T) {
+ tests := []struct {
+ name string
+ queryAddr tcpip.Address
+ maxDelay time.Duration
+ expectQueriedReportsFor []tcpip.Address
+ expectDelayedReportsFor []tcpip.Address
+ }{
+ {
+ name: "Unpecified empty",
+ queryAddr: "",
+ maxDelay: 0,
+ expectQueriedReportsFor: []tcpip.Address{addr1, addr2},
+ expectDelayedReportsFor: nil,
+ },
+ {
+ name: "Unpecified any",
+ queryAddr: "\x00",
+ maxDelay: 1,
+ expectQueriedReportsFor: []tcpip.Address{addr1, addr2},
+ expectDelayedReportsFor: nil,
+ },
+ {
+ name: "Specified",
+ queryAddr: addr1,
+ maxDelay: 2,
+ expectQueriedReportsFor: []tcpip.Address{addr1},
+ expectDelayedReportsFor: []tcpip.Address{addr2},
+ },
+ {
+ name: "Specified all-nodes",
+ queryAddr: addr3,
+ maxDelay: 3,
+ expectQueriedReportsFor: nil,
+ expectDelayedReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Specified other",
+ queryAddr: addr4,
+ maxDelay: 4,
+ expectQueriedReportsFor: nil,
+ expectDelayedReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ mgp.joinGroup(addr1)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr2)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr3)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a query should make us reschedule our delayed report timer
+ // to some time within the new max response delay.
+ mgp.handleQuery(test.queryAddr, test.maxDelay)
+ clock.Advance(test.maxDelay)
+ if diff := mgp.check(test.expectQueriedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // The groups that were not affected by the query should still send a
+ // report after the max unsolicited report delay.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check(test.expectDelayedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestJoinCount(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(4)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: time.Second,
+ })
+
+ // Set the join count to 2 for a group.
+ mgp.joinGroup(addr1)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
+ }
+ // Only the first join should trigger a report to be sent.
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr1)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Group should still be considered joined after leaving once.
+ if !mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
+ }
+ if !mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
+ }
+ // A leave report should only be sent once the join count reaches 0.
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leaving once more should actually remove us from the group.
+ if !mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
+ }
+ if mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Group should no longer be joined so we should not have anything to
+ // leave.
+ if mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1)
+ }
+ if mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should have no more messages to send.
+ //
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+func TestMakeAllNonMemberAndInitialize(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ mgp.joinGroup(addr1)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr2)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr3)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should send the leave reports for each but still consider them locally
+ // joined.
+ mgp.makeAllNonMember()
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ for _, group := range []tcpip.Address{addr1, addr2, addr3} {
+ if !mgp.isLocallyJoined(group) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group)
+ }
+ }
+
+ // Should send the initial set of unsolcited reports.
+ mgp.initializeGroups()
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+// TestGroupStateNonMember tests that groups do not send packets when in the
+// non-member state, but are still considered locally joined.
+func TestGroupStateNonMember(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ })
+ mgp.setEnabled(false)
+
+ // Joining groups should not send any reports.
+ mgp.joinGroup(addr1)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr2)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a query should not send any reports.
+ mgp.handleQuery(addr1, time.Nanosecond)
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Leaving groups should not send any leave messages.
+ if !mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr2)
+ }
+ if mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr2)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+func TestQueuedPackets(t *testing.T) {
+ clock := faketime.NewManualClock()
+ mgp := mockMulticastGroupProtocol{t: t}
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(4)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ })
+
+ // Joining should trigger a SendReport, but mgp should report that we did not
+ // send the packet.
+ mgp.setQueuePackets(true)
+ mgp.joinGroup(addr1)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // The delayed report timer should have been cancelled since we did not send
+ // the initial report earlier.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Mock being able to successfully send the report.
+ mgp.setQueuePackets(false)
+ mgp.sendQueuedReports()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // The delayed report (sent after the initial report) should now be sent.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send (we should be idle).
+ mgp.sendQueuedReports()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receive a query but mock being unable to send reports again.
+ mgp.setQueuePackets(true)
+ mgp.handleQuery(addr1, time.Nanosecond)
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Mock being able to send reports again - we should have a packet queued to
+ // send.
+ mgp.setQueuePackets(false)
+ mgp.sendQueuedReports()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send.
+ mgp.sendQueuedReports()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receive a query again, but mock being unable to send reports.
+ mgp.setQueuePackets(true)
+ mgp.handleQuery(addr1, time.Nanosecond)
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a report should should transition us into the idle member state,
+ // even if we had a packet queued. We should no longer have any packets to
+ // send.
+ mgp.handleReport(addr1)
+ mgp.sendQueuedReports()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // When we fail to send the initial set of reports, incoming reports should
+ // not affect a newly joined group's reports from being sent.
+ mgp.setQueuePackets(true)
+ mgp.joinGroup(addr2)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.handleReport(addr2)
+ // Attempting to send queued reports while still unable to send reports should
+ // not change the host state.
+ mgp.sendQueuedReports()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // Mock being able to successfully send the report.
+ mgp.setQueuePackets(false)
+ mgp.sendQueuedReports()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // The delayed report (sent after the initial report) should now be sent.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send.
+ mgp.sendQueuedReports()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go
new file mode 100644
index 000000000..898f8b356
--- /dev/null
+++ b/pkg/tcpip/network/internal/ip/stats.go
@@ -0,0 +1,100 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip
+
+import "gvisor.dev/gvisor/pkg/tcpip"
+
+// LINT.IfChange(MultiCounterIPStats)
+
+// MultiCounterIPStats holds IP statistics, each counter may have several
+// versions.
+type MultiCounterIPStats struct {
+ // PacketsReceived is the total number of IP packets received from the link
+ // layer.
+ PacketsReceived tcpip.MultiCounterStat
+
+ // DisabledPacketsReceived is the total number of IP packets received from the
+ // link layer when the IP layer is disabled.
+ DisabledPacketsReceived tcpip.MultiCounterStat
+
+ // InvalidDestinationAddressesReceived is the total number of IP packets
+ // received with an unknown or invalid destination address.
+ InvalidDestinationAddressesReceived tcpip.MultiCounterStat
+
+ // InvalidSourceAddressesReceived is the total number of IP packets received
+ // with a source address that should never have been received on the wire.
+ InvalidSourceAddressesReceived tcpip.MultiCounterStat
+
+ // PacketsDelivered is the total number of incoming IP packets that are
+ // successfully delivered to the transport layer.
+ PacketsDelivered tcpip.MultiCounterStat
+
+ // PacketsSent is the total number of IP packets sent via WritePacket.
+ PacketsSent tcpip.MultiCounterStat
+
+ // OutgoingPacketErrors is the total number of IP packets which failed to
+ // write to a link-layer endpoint.
+ OutgoingPacketErrors tcpip.MultiCounterStat
+
+ // MalformedPacketsReceived is the total number of IP Packets that were
+ // dropped due to the IP packet header failing validation checks.
+ MalformedPacketsReceived tcpip.MultiCounterStat
+
+ // MalformedFragmentsReceived is the total number of IP Fragments that were
+ // dropped due to the fragment failing validation checks.
+ MalformedFragmentsReceived tcpip.MultiCounterStat
+
+ // IPTablesPreroutingDropped is the total number of IP packets dropped in the
+ // Prerouting chain.
+ IPTablesPreroutingDropped tcpip.MultiCounterStat
+
+ // IPTablesInputDropped is the total number of IP packets dropped in the Input
+ // chain.
+ IPTablesInputDropped tcpip.MultiCounterStat
+
+ // IPTablesOutputDropped is the total number of IP packets dropped in the
+ // Output chain.
+ IPTablesOutputDropped tcpip.MultiCounterStat
+
+ // OptionTSReceived is the number of Timestamp options seen.
+ OptionTSReceived tcpip.MultiCounterStat
+
+ // OptionRRReceived is the number of Record Route options seen.
+ OptionRRReceived tcpip.MultiCounterStat
+
+ // OptionUnknownReceived is the number of unknown IP options seen.
+ OptionUnknownReceived tcpip.MultiCounterStat
+}
+
+// Init sets internal counters to track a and b counters.
+func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
+ m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived)
+ m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived)
+ m.InvalidDestinationAddressesReceived.Init(a.InvalidDestinationAddressesReceived, b.InvalidDestinationAddressesReceived)
+ m.InvalidSourceAddressesReceived.Init(a.InvalidSourceAddressesReceived, b.InvalidSourceAddressesReceived)
+ m.PacketsDelivered.Init(a.PacketsDelivered, b.PacketsDelivered)
+ m.PacketsSent.Init(a.PacketsSent, b.PacketsSent)
+ m.OutgoingPacketErrors.Init(a.OutgoingPacketErrors, b.OutgoingPacketErrors)
+ m.MalformedPacketsReceived.Init(a.MalformedPacketsReceived, b.MalformedPacketsReceived)
+ m.MalformedFragmentsReceived.Init(a.MalformedFragmentsReceived, b.MalformedFragmentsReceived)
+ m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped)
+ m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped)
+ m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped)
+ m.OptionTSReceived.Init(a.OptionTSReceived, b.OptionTSReceived)
+ m.OptionRRReceived.Init(a.OptionRRReceived, b.OptionRRReceived)
+ m.OptionUnknownReceived.Init(a.OptionUnknownReceived, b.OptionUnknownReceived)
+}
+
+// LINT.ThenChange(:MultiCounterIPStats, ../../tcpip.go:IPStats)
diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD
new file mode 100644
index 000000000..1c4f583c7
--- /dev/null
+++ b/pkg/tcpip/network/internal/testutil/BUILD
@@ -0,0 +1,23 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "testutil",
+ srcs = [
+ "testutil.go",
+ "testutil_unsafe.go",
+ ],
+ visibility = [
+ "//pkg/tcpip/network/arp:__pkg__",
+ "//pkg/tcpip/network/internal/fragmentation:__pkg__",
+ "//pkg/tcpip/network/ipv4:__pkg__",
+ "//pkg/tcpip/network/ipv6:__pkg__",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go
new file mode 100644
index 000000000..f5fa77b65
--- /dev/null
+++ b/pkg/tcpip/network/internal/testutil/testutil.go
@@ -0,0 +1,197 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testutil defines types and functions used to test Network Layer
+// functionality such as IP fragmentation.
+package testutil
+
+import (
+ "fmt"
+ "math/rand"
+ "reflect"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// MockLinkEndpoint is an endpoint used for testing, it stores packets written
+// to it and can mock errors.
+type MockLinkEndpoint struct {
+ // WrittenPackets is where packets written to the endpoint are stored.
+ WrittenPackets []*stack.PacketBuffer
+
+ mtu uint32
+ err tcpip.Error
+ allowPackets int
+}
+
+// NewMockLinkEndpoint creates a new MockLinkEndpoint.
+//
+// err is the error that will be returned once allowPackets packets are written
+// to the endpoint.
+func NewMockLinkEndpoint(mtu uint32, err tcpip.Error, allowPackets int) *MockLinkEndpoint {
+ return &MockLinkEndpoint{
+ mtu: mtu,
+ err: err,
+ allowPackets: allowPackets,
+ }
+}
+
+// MTU implements LinkEndpoint.MTU.
+func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu }
+
+// Capabilities implements LinkEndpoint.Capabilities.
+func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 }
+
+// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
+func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 }
+
+// LinkAddress implements LinkEndpoint.LinkAddress.
+func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
+
+// WritePacket implements LinkEndpoint.WritePacket.
+func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ if ep.allowPackets == 0 {
+ return ep.err
+ }
+ ep.allowPackets--
+ ep.WrittenPackets = append(ep.WrittenPackets, pkt)
+ return nil
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ var n int
+
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err := ep.WritePacket(r, gso, protocol, pkt); err != nil {
+ return n, err
+ }
+ n++
+ }
+
+ return n, nil
+}
+
+// Attach implements LinkEndpoint.Attach.
+func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {}
+
+// IsAttached implements LinkEndpoint.IsAttached.
+func (*MockLinkEndpoint) IsAttached() bool { return false }
+
+// Wait implements LinkEndpoint.Wait.
+func (*MockLinkEndpoint) Wait() {}
+
+// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
+func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
+
+// AddHeader implements LinkEndpoint.AddHeader.
+func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
+
+// MakeRandPkt generates a randomized packet. transportHeaderLength indicates
+// how many random bytes will be copied in the Transport Header.
+// extraHeaderReserveLength indicates how much extra space will be reserved for
+// the other headers. The payload is made from Views of the sizes listed in
+// viewSizes.
+func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer {
+ var views buffer.VectorisedView
+
+ for _, s := range viewSizes {
+ newView := buffer.NewView(s)
+ if _, err := rand.Read(newView); err != nil {
+ panic(fmt.Sprintf("rand.Read: %s", err))
+ }
+ views.AppendView(newView)
+ }
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength,
+ Data: views,
+ })
+ pkt.NetworkProtocolNumber = proto
+ if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil {
+ panic(fmt.Sprintf("rand.Read: %s", err))
+ }
+ return pkt
+}
+
+func checkFieldCounts(ref, multi reflect.Value) error {
+ refTypeName := ref.Type().Name()
+ multiTypeName := multi.Type().Name()
+ refNumField := ref.NumField()
+ multiNumField := multi.NumField()
+
+ if refNumField != multiNumField {
+ return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName)
+ }
+
+ return nil
+}
+
+func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error {
+ s, ok := ref.Addr().Interface().(**tcpip.StatCounter)
+ if !ok {
+ return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name())
+ }
+
+ // The field names are expected to match (case insensitive).
+ if !strings.EqualFold(refName, multiName) {
+ return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName)
+ }
+
+ base := (*s).Value()
+ m.Increment()
+ if (*s).Value() != base+1 {
+ return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName)
+ }
+
+ return nil
+}
+
+// ValidateMultiCounterStats verifies that every counter stored in multi is
+// correctly tracking its counterpart in the given counters.
+func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error {
+ for _, c := range counters {
+ if err := checkFieldCounts(c, multi); err != nil {
+ return err
+ }
+ }
+
+ for i := 0; i < multi.NumField(); i++ {
+ multiName := multi.Type().Field(i).Name
+ multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i))
+
+ if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok {
+ for _, c := range counters {
+ if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil {
+ return err
+ }
+ }
+ } else {
+ var countersNextField []reflect.Value
+ for _, c := range counters {
+ countersNextField = append(countersNextField, c.Field(i))
+ }
+ if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go b/pkg/tcpip/network/internal/testutil/testutil_unsafe.go
new file mode 100644
index 000000000..5ff764800
--- /dev/null
+++ b/pkg/tcpip/network/internal/testutil/testutil_unsafe.go
@@ -0,0 +1,26 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// unsafeExposeUnexportedFields takes a Value and returns a version of it in
+// which even unexported fields can be read and written.
+func unsafeExposeUnexportedFields(a reflect.Value) reflect.Value {
+ return reflect.NewAt(a.Type(), unsafe.Pointer(a.UnsafeAddr())).Elem()
+}