From 18e993eb4f2e6db829acfb5e8725f7d12f73ab67 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 9 Feb 2021 11:48:08 -0800 Subject: Move network internal code to internal package Utilities written to be common across IPv4/IPv6 are not planned to be available for public use. https://golang.org/doc/go1.4#internalpackages PiperOrigin-RevId: 356554862 --- pkg/tcpip/network/arp/BUILD | 2 +- pkg/tcpip/network/arp/stats_test.go | 2 +- pkg/tcpip/network/fragmentation/BUILD | 51 -- pkg/tcpip/network/fragmentation/fragmentation.go | 339 --------- .../network/fragmentation/fragmentation_test.go | 638 ---------------- pkg/tcpip/network/fragmentation/reassembler.go | 182 ----- .../network/fragmentation/reassembler_test.go | 233 ------ pkg/tcpip/network/internal/fragmentation/BUILD | 54 ++ .../internal/fragmentation/fragmentation.go | 339 +++++++++ .../internal/fragmentation/fragmentation_test.go | 638 ++++++++++++++++ .../network/internal/fragmentation/reassembler.go | 182 +++++ .../internal/fragmentation/reassembler_test.go | 233 ++++++ pkg/tcpip/network/internal/ip/BUILD | 17 +- .../internal/ip/generic_multicast_protocol.go | 696 ++++++++++++++++++ .../internal/ip/generic_multicast_protocol_test.go | 805 ++++++++++++++++++++ pkg/tcpip/network/internal/ip/stats.go | 100 +++ pkg/tcpip/network/internal/testutil/BUILD | 23 + pkg/tcpip/network/internal/testutil/testutil.go | 197 +++++ .../network/internal/testutil/testutil_unsafe.go | 26 + pkg/tcpip/network/ip/BUILD | 29 - pkg/tcpip/network/ip/generic_multicast_protocol.go | 696 ------------------ .../network/ip/generic_multicast_protocol_test.go | 812 --------------------- pkg/tcpip/network/ip/stats.go | 100 --- pkg/tcpip/network/ipv4/BUILD | 8 +- pkg/tcpip/network/ipv4/igmp.go | 2 +- pkg/tcpip/network/ipv4/ipv4.go | 2 +- pkg/tcpip/network/ipv4/ipv4_test.go | 2 +- pkg/tcpip/network/ipv4/stats.go | 2 +- pkg/tcpip/network/ipv4/stats_test.go | 2 +- pkg/tcpip/network/ipv6/BUILD | 5 +- pkg/tcpip/network/ipv6/ipv6.go | 2 +- 
pkg/tcpip/network/ipv6/ipv6_test.go | 2 +- pkg/tcpip/network/ipv6/mld.go | 2 +- pkg/tcpip/network/ipv6/stats.go | 2 +- pkg/tcpip/network/testutil/BUILD | 23 - pkg/tcpip/network/testutil/testutil.go | 197 ----- pkg/tcpip/network/testutil/testutil_unsafe.go | 26 - 37 files changed, 3324 insertions(+), 3347 deletions(-) delete mode 100644 pkg/tcpip/network/fragmentation/BUILD delete mode 100644 pkg/tcpip/network/fragmentation/fragmentation.go delete mode 100644 pkg/tcpip/network/fragmentation/fragmentation_test.go delete mode 100644 pkg/tcpip/network/fragmentation/reassembler.go delete mode 100644 pkg/tcpip/network/fragmentation/reassembler_test.go create mode 100644 pkg/tcpip/network/internal/fragmentation/BUILD create mode 100644 pkg/tcpip/network/internal/fragmentation/fragmentation.go create mode 100644 pkg/tcpip/network/internal/fragmentation/fragmentation_test.go create mode 100644 pkg/tcpip/network/internal/fragmentation/reassembler.go create mode 100644 pkg/tcpip/network/internal/fragmentation/reassembler_test.go create mode 100644 pkg/tcpip/network/internal/ip/generic_multicast_protocol.go create mode 100644 pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go create mode 100644 pkg/tcpip/network/internal/ip/stats.go create mode 100644 pkg/tcpip/network/internal/testutil/BUILD create mode 100644 pkg/tcpip/network/internal/testutil/testutil.go create mode 100644 pkg/tcpip/network/internal/testutil/testutil_unsafe.go delete mode 100644 pkg/tcpip/network/ip/BUILD delete mode 100644 pkg/tcpip/network/ip/generic_multicast_protocol.go delete mode 100644 pkg/tcpip/network/ip/generic_multicast_protocol_test.go delete mode 100644 pkg/tcpip/network/ip/stats.go delete mode 100644 pkg/tcpip/network/testutil/BUILD delete mode 100644 pkg/tcpip/network/testutil/testutil.go delete mode 100644 pkg/tcpip/network/testutil/testutil_unsafe.go (limited to 'pkg') diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index 29c8cdffd..d59d678b2 100644 --- 
a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -46,7 +46,7 @@ go_test( library = ":arp", deps = [ "//pkg/tcpip", - "//pkg/tcpip/network/testutil", + "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go index 65c708ac4..e867b3c3f 100644 --- a/pkg/tcpip/network/arp/stats_test.go +++ b/pkg/tcpip/network/arp/stats_test.go @@ -19,7 +19,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/testutil" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" ) diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD deleted file mode 100644 index 429af69ee..000000000 --- a/pkg/tcpip/network/fragmentation/BUILD +++ /dev/null @@ -1,51 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "reassembler_list", - out = "reassembler_list.go", - package = "fragmentation", - prefix = "reassembler", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*reassembler", - "Linker": "*reassembler", - }, -) - -go_library( - name = "fragmentation", - srcs = [ - "fragmentation.go", - "reassembler.go", - "reassembler_list.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "fragmentation_test", - size = "small", - srcs = [ - "fragmentation_test.go", - "reassembler_test.go", - ], - library = ":fragmentation", - deps = [ - "//pkg/tcpip/buffer", - "//pkg/tcpip/faketime", - "//pkg/tcpip/network/testutil", - "//pkg/tcpip/stack", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go 
b/pkg/tcpip/network/fragmentation/fragmentation.go deleted file mode 100644 index 243738951..000000000 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ /dev/null @@ -1,339 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package fragmentation contains the implementation of IP fragmentation. -// It is based on RFC 791, RFC 815 and RFC 8200. -package fragmentation - -import ( - "errors" - "fmt" - "log" - "time" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - // HighFragThreshold is the threshold at which we start trimming old - // fragmented packets. Linux uses a default value of 4 MB. See - // net.ipv4.ipfrag_high_thresh for more information. - HighFragThreshold = 4 << 20 // 4MB - - // LowFragThreshold is the threshold we reach to when we start dropping - // older fragmented packets. It's important that we keep enough room for newer - // packets to be re-assembled. Hence, this needs to be lower than - // HighFragThreshold enough. Linux uses a default value of 3 MB. See - // net.ipv4.ipfrag_low_thresh for more information. - LowFragThreshold = 3 << 20 // 3MB - - // minBlockSize is the minimum block size for fragments. - minBlockSize = 1 -) - -var ( - // ErrInvalidArgs indicates to the caller that an invalid argument was - // provided. 
- ErrInvalidArgs = errors.New("invalid args") - - // ErrFragmentOverlap indicates that, during reassembly, a fragment overlaps - // with another one. - ErrFragmentOverlap = errors.New("overlapping fragments") - - // ErrFragmentConflict indicates that, during reassembly, some fragments are - // in conflict with one another. - ErrFragmentConflict = errors.New("conflicting fragments") -) - -// FragmentID is the identifier for a fragment. -type FragmentID struct { - // Source is the source address of the fragment. - Source tcpip.Address - - // Destination is the destination address of the fragment. - Destination tcpip.Address - - // ID is the identification value of the fragment. - // - // This is a uint32 because IPv6 uses a 32-bit identification value. - ID uint32 - - // The protocol for the packet. - Protocol uint8 -} - -// Fragmentation is the main structure that other modules -// of the stack should use to implement IP Fragmentation. -type Fragmentation struct { - mu sync.Mutex - highLimit int - lowLimit int - reassemblers map[FragmentID]*reassembler - rList reassemblerList - memSize int - timeout time.Duration - blockSize uint16 - clock tcpip.Clock - releaseJob *tcpip.Job - timeoutHandler TimeoutHandler -} - -// TimeoutHandler is consulted if a packet reassembly has timed out. -type TimeoutHandler interface { - // OnReassemblyTimeout will be called with the first fragment (or nil, if the - // first fragment has not been received) of a packet whose reassembly has - // timed out. - OnReassemblyTimeout(pkt *stack.PacketBuffer) -} - -// NewFragmentation creates a new Fragmentation. -// -// blockSize specifies the fragment block size, in bytes. -// -// highMemoryLimit specifies the limit on the memory consumed -// by the fragments stored by Fragmentation (overhead of internal data-structures -// is not accounted). Fragments are dropped when the limit is reached. 
-// -// lowMemoryLimit specifies the limit on which we will reach by dropping -// fragments after reaching highMemoryLimit. -// -// reassemblingTimeout specifies the maximum time allowed to reassemble a packet. -// Fragments are lazily evicted only when a new a packet with an -// already existing fragmentation-id arrives after the timeout. -func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock, timeoutHandler TimeoutHandler) *Fragmentation { - if lowMemoryLimit >= highMemoryLimit { - lowMemoryLimit = highMemoryLimit - } - - if lowMemoryLimit < 0 { - lowMemoryLimit = 0 - } - - if blockSize < minBlockSize { - blockSize = minBlockSize - } - - f := &Fragmentation{ - reassemblers: make(map[FragmentID]*reassembler), - highLimit: highMemoryLimit, - lowLimit: lowMemoryLimit, - timeout: reassemblingTimeout, - blockSize: blockSize, - clock: clock, - timeoutHandler: timeoutHandler, - } - f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked) - - return f -} - -// Process processes an incoming fragment belonging to an ID and returns a -// complete packet and its protocol number when all the packets belonging to -// that ID have been received. -// -// [first, last] is the range of the fragment bytes. -// -// first must be a multiple of the block size f is configured with. The size -// of the fragment data must be a multiple of the block size, unless there are -// no fragments following this fragment (more set to false). -// -// proto is the protocol number marked in the fragment being processed. It has -// to be given here outside of the FragmentID struct because IPv6 should not use -// the protocol to identify a fragment. 
-func (f *Fragmentation) Process( - id FragmentID, first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) ( - *stack.PacketBuffer, uint8, bool, error) { - if first > last { - return nil, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) - } - - if first%f.blockSize != 0 { - return nil, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs) - } - - fragmentSize := last - first + 1 - if more && fragmentSize%f.blockSize != 0 { - return nil, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) - } - - if l := pkt.Data.Size(); l != int(fragmentSize) { - return nil, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) - } - - f.mu.Lock() - r, ok := f.reassemblers[id] - if !ok { - r = newReassembler(id, f.clock) - f.reassemblers[id] = r - wasEmpty := f.rList.Empty() - f.rList.PushFront(r) - if wasEmpty { - // If we have just pushed a first reassembler into an empty list, we - // should kickstart the release job. The release job will keep - // rescheduling itself until the list becomes empty. - f.releaseReassemblersLocked() - } - } - f.mu.Unlock() - - resPkt, firstFragmentProto, done, memConsumed, err := r.process(first, last, more, proto, pkt) - if err != nil { - // We probably got an invalid sequence of fragments. Just - // discard the reassembler and move on. - f.mu.Lock() - f.release(r, false /* timedOut */) - f.mu.Unlock() - return nil, 0, false, fmt.Errorf("fragmentation processing error: %w", err) - } - f.mu.Lock() - f.memSize += memConsumed - if done { - f.release(r, false /* timedOut */) - } - // Evict reassemblers if we are consuming more memory than highLimit until - // we reach lowLimit. 
- if f.memSize > f.highLimit { - for f.memSize > f.lowLimit { - tail := f.rList.Back() - if tail == nil { - break - } - f.release(tail, false /* timedOut */) - } - } - f.mu.Unlock() - return resPkt, firstFragmentProto, done, nil -} - -func (f *Fragmentation) release(r *reassembler, timedOut bool) { - // Before releasing a fragment we need to check if r is already marked as done. - // Otherwise, we would delete it twice. - if r.checkDoneOrMark() { - return - } - - delete(f.reassemblers, r.id) - f.rList.Remove(r) - f.memSize -= r.memSize - if f.memSize < 0 { - log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.memSize) - f.memSize = 0 - } - - if h := f.timeoutHandler; timedOut && h != nil { - h.OnReassemblyTimeout(r.pkt) - } -} - -// releaseReassemblersLocked releases already-expired reassemblers, then -// schedules the job to call back itself for the remaining reassemblers if -// any. This function must be called with f.mu locked. -func (f *Fragmentation) releaseReassemblersLocked() { - now := f.clock.NowMonotonic() - for { - // The reassembler at the end of the list is the oldest. - r := f.rList.Back() - if r == nil { - // The list is empty. - break - } - elapsed := time.Duration(now-r.creationTime) * time.Nanosecond - if f.timeout > elapsed { - // If the oldest reassembler has not expired, schedule the release - // job so that this function is called back when it has expired. - f.releaseJob.Schedule(f.timeout - elapsed) - break - } - // If the oldest reassembler has already expired, release it. - f.release(r, true /* timedOut*/) - } -} - -// PacketFragmenter is the book-keeping struct for packet fragmentation. -type PacketFragmenter struct { - transportHeader buffer.View - data buffer.VectorisedView - reserve int - fragmentPayloadLen int - fragmentCount int - currentFragment int - fragmentOffset int -} - -// MakePacketFragmenter prepares the struct needed for packet fragmentation. 
-// -// pkt is the packet to be fragmented. -// -// fragmentPayloadLen is the maximum number of bytes of fragmentable data a fragment can -// have. -// -// reserve is the number of bytes that should be reserved for the headers in -// each generated fragment. -func MakePacketFragmenter(pkt *stack.PacketBuffer, fragmentPayloadLen uint32, reserve int) PacketFragmenter { - // As per RFC 8200 Section 4.5, some IPv6 extension headers should not be - // repeated in each fragment. However we do not currently support any header - // of that kind yet, so the following computation is valid for both IPv4 and - // IPv6. - // TODO(gvisor.dev/issue/3912): Once Authentication or ESP Headers are - // supported for outbound packets, the fragmentable data should not include - // these headers. - var fragmentableData buffer.VectorisedView - fragmentableData.AppendView(pkt.TransportHeader().View()) - fragmentableData.Append(pkt.Data) - fragmentCount := (uint32(fragmentableData.Size()) + fragmentPayloadLen - 1) / fragmentPayloadLen - - return PacketFragmenter{ - data: fragmentableData, - reserve: reserve, - fragmentPayloadLen: int(fragmentPayloadLen), - fragmentCount: int(fragmentCount), - } -} - -// BuildNextFragment returns a packet with the payload of the next fragment, -// along with the fragment's offset, the number of bytes copied and a boolean -// indicating if there are more fragments left or not. If this function is -// called again after it indicated that no more fragments were left, it will -// panic. -// -// Note that the returned packet will not have its network and link headers -// populated, but space for them will be reserved. The transport header will be -// stored in the packet's data. 
-func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, bool) { - if pf.currentFragment >= pf.fragmentCount { - panic("BuildNextFragment should not be called again after the last fragment was returned") - } - - fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: pf.reserve, - }) - - // Copy data for the fragment. - copied := pf.data.ReadToVV(&fragPkt.Data, pf.fragmentPayloadLen) - - offset := pf.fragmentOffset - pf.fragmentOffset += copied - pf.currentFragment++ - more := pf.currentFragment != pf.fragmentCount - - return fragPkt, offset, copied, more -} - -// RemainingFragmentCount returns the number of fragments left to be built. -func (pf *PacketFragmenter) RemainingFragmentCount() int { - return pf.fragmentCount - pf.currentFragment -} diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go deleted file mode 100644 index 905bbc19b..000000000 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ /dev/null @@ -1,638 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package fragmentation - -import ( - "errors" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/network/testutil" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// reassembleTimeout is dummy timeout used for testing, where the clock never -// advances. -const reassembleTimeout = 1 - -// vv is a helper to build VectorisedView from different strings. -func vv(size int, pieces ...string) buffer.VectorisedView { - views := make([]buffer.View, len(pieces)) - for i, p := range pieces { - views[i] = []byte(p) - } - - return buffer.NewVectorisedView(size, views) -} - -func pkt(size int, pieces ...string) *stack.PacketBuffer { - return stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv(size, pieces...), - }) -} - -type processInput struct { - id FragmentID - first uint16 - last uint16 - more bool - proto uint8 - pkt *stack.PacketBuffer -} - -type processOutput struct { - vv buffer.VectorisedView - proto uint8 - done bool -} - -var processTestCases = []struct { - comment string - in []processInput - out []processOutput -}{ - { - comment: "One ID", - in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "01", "23"), done: true}, - }, - }, - { - comment: "Next Header protocol mismatch", - in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, pkt: pkt(2, "01")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, pkt: pkt(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "01", "23"), proto: 6, done: true}, - }, - }, - { - comment: "Two IDs", - in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, - {id: 
FragmentID{ID: 1}, first: 0, last: 1, more: true, pkt: pkt(2, "ab")}, - {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, pkt: pkt(2, "cd")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "ab", "cd"), done: true}, - {vv: vv(4, "01", "23"), done: true}, - }, - }, -} - -func TestFragmentationProcess(t *testing.T) { - for _, c := range processTestCases { - t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil) - firstFragmentProto := c.in[0].proto - for i, in := range c.in { - resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) - if err != nil { - t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s", - in.id, in.first, in.last, in.more, in.proto, in.pkt, err) - } - if done != c.out[i].done { - t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)", - in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done) - } - if c.out[i].done { - if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data.ToOwnedView()); diff != "" { - t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s", - in.id, in.first, in.last, in.more, in.proto, in.pkt, diff) - } - if firstFragmentProto != proto { - t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)", - in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto) - } - if _, ok := f.reassemblers[in.id]; ok { - t.Errorf("Process(%d) did not remove buffer from reassemblers", i) - } - for n := f.rList.Front(); n != nil; n = n.Next() { - if n.id == in.id { - t.Errorf("Process(%d) did not remove buffer from rList", i) - } - } - } - } - }) - } -} - -func TestReassemblingTimeout(t *testing.T) { - const ( - reassemblyTimeout = time.Millisecond - 
protocol = 0xff - ) - - type fragment struct { - first uint16 - last uint16 - more bool - data string - } - - type event struct { - // name is a nickname of this event. - name string - - // clockAdvance is a duration to advance the clock. The clock advances - // before a fragment specified in the fragment field is processed. - clockAdvance time.Duration - - // fragment is a fragment to process. This can be nil if there is no - // fragment to process. - fragment *fragment - - // expectDone is true if the fragmentation instance should report the - // reassembly is done after the fragment is processd. - expectDone bool - - // memSizeAfterEvent is the expected memory size of the fragmentation - // instance after the event. - memSizeAfterEvent int - } - - memSizeOfFrags := func(frags ...*fragment) int { - var size int - for _, frag := range frags { - size += pkt(len(frag.data), frag.data).MemSize() - } - return size - } - - half1 := &fragment{first: 0, last: 0, more: true, data: "0"} - half2 := &fragment{first: 1, last: 1, more: false, data: "1"} - - tests := []struct { - name string - events []event - }{ - { - name: "half1 and half2 are reassembled successfully", - events: []event{ - { - name: "half1", - fragment: half1, - expectDone: false, - memSizeAfterEvent: memSizeOfFrags(half1), - }, - { - name: "half2", - fragment: half2, - expectDone: true, - memSizeAfterEvent: 0, - }, - }, - }, - { - name: "half1 timeout, half2 timeout", - events: []event{ - { - name: "half1", - fragment: half1, - expectDone: false, - memSizeAfterEvent: memSizeOfFrags(half1), - }, - { - name: "half1 just before reassembly timeout", - clockAdvance: reassemblyTimeout - 1, - memSizeAfterEvent: memSizeOfFrags(half1), - }, - { - name: "half1 reassembly timeout", - clockAdvance: 1, - memSizeAfterEvent: 0, - }, - { - name: "half2", - fragment: half2, - expectDone: false, - memSizeAfterEvent: memSizeOfFrags(half2), - }, - { - name: "half2 just before reassembly timeout", - clockAdvance: 
reassemblyTimeout - 1, - memSizeAfterEvent: memSizeOfFrags(half2), - }, - { - name: "half2 reassembly timeout", - clockAdvance: 1, - memSizeAfterEvent: 0, - }, - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock, nil) - for _, event := range test.events { - clock.Advance(event.clockAdvance) - if frag := event.fragment; frag != nil { - _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, pkt(len(frag.data), frag.data)) - if err != nil { - t.Fatalf("%s: f.Process failed: %s", event.name, err) - } - if done != event.expectDone { - t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone) - } - } - if got, want := f.memSize, event.memSizeAfterEvent; got != want { - t.Errorf("%s: got f.memSize = %d, want = %d", event.name, got, want) - } - } - }) - } -} - -func TestMemoryLimits(t *testing.T) { - lowLimit := pkt(1, "0").MemSize() - highLimit := 3 * lowLimit // Allow at most 3 such packets. - f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil) - // Send first fragment with id = 0. - f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")) - // Send first fragment with id = 1. - f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")) - // Send first fragment with id = 2. - f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")) - - // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be - // evicted. 
- f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")) - - if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { - t.Errorf("Memory limits are not respected: id=0 has not been evicted.") - } - if _, ok := f.reassemblers[FragmentID{ID: 1}]; ok { - t.Errorf("Memory limits are not respected: id=1 has not been evicted.") - } - if _, ok := f.reassemblers[FragmentID{ID: 3}]; !ok { - t.Errorf("Implementation of memory limits is wrong: id=3 is not present.") - } -} - -func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - memSize := pkt(1, "0").MemSize() - f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil) - // Send first fragment with id = 0. - f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) - // Send the same packet again. - f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) - - if got, want := f.memSize, memSize; got != want { - t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) - } -} - -func TestErrors(t *testing.T) { - tests := []struct { - name string - blockSize uint16 - first uint16 - last uint16 - more bool - data string - err error - }{ - { - name: "exact block size without more", - blockSize: 2, - first: 2, - last: 3, - more: false, - data: "01", - }, - { - name: "exact block size with more", - blockSize: 2, - first: 2, - last: 3, - more: true, - data: "01", - }, - { - name: "exact block size with more and extra data", - blockSize: 2, - first: 2, - last: 3, - more: true, - data: "012", - err: ErrInvalidArgs, - }, - { - name: "exact block size with more and too little data", - blockSize: 2, - first: 2, - last: 3, - more: true, - data: "0", - err: ErrInvalidArgs, - }, - { - name: "not exact block size with more", - blockSize: 2, - first: 2, - last: 2, - more: true, - data: "0", - err: ErrInvalidArgs, - }, - { - name: "not exact block size without more", - blockSize: 2, - first: 2, - last: 2, - more: false, - data: "0", - }, - { - name: "first not a multiple of 
block size", - blockSize: 2, - first: 3, - last: 4, - more: true, - data: "01", - err: ErrInvalidArgs, - }, - { - name: "first more than last", - blockSize: 2, - first: 4, - last: 3, - more: true, - data: "01", - err: ErrInvalidArgs, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, nil) - _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, pkt(len(test.data), test.data)) - if !errors.Is(err, test.err) { - t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) - } - if done { - t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, true, _), want = (_, _, false, _)", test.first, test.last, test.more, test.data) - } - }) - } -} - -type fragmentInfo struct { - remaining int - copied int - offset int - more bool -} - -func TestPacketFragmenter(t *testing.T) { - const ( - reserve = 60 - proto = 0 - ) - - tests := []struct { - name string - fragmentPayloadLen uint32 - transportHeaderLen int - payloadSize int - wantFragments []fragmentInfo - }{ - { - name: "Packet exactly fits in MTU", - fragmentPayloadLen: 1280, - transportHeaderLen: 0, - payloadSize: 1280, - wantFragments: []fragmentInfo{ - {remaining: 0, copied: 1280, offset: 0, more: false}, - }, - }, - { - name: "Packet exactly does not fit in MTU", - fragmentPayloadLen: 1000, - transportHeaderLen: 0, - payloadSize: 1001, - wantFragments: []fragmentInfo{ - {remaining: 1, copied: 1000, offset: 0, more: true}, - {remaining: 0, copied: 1, offset: 1000, more: false}, - }, - }, - { - name: "Packet has a transport header", - fragmentPayloadLen: 560, - transportHeaderLen: 40, - payloadSize: 560, - wantFragments: []fragmentInfo{ - {remaining: 1, copied: 560, offset: 0, more: true}, - {remaining: 0, copied: 40, offset: 560, more: false}, - }, - }, - { - name: 
"Packet has a huge transport header", - fragmentPayloadLen: 500, - transportHeaderLen: 1300, - payloadSize: 500, - wantFragments: []fragmentInfo{ - {remaining: 3, copied: 500, offset: 0, more: true}, - {remaining: 2, copied: 500, offset: 500, more: true}, - {remaining: 1, copied: 500, offset: 1000, more: true}, - {remaining: 0, copied: 300, offset: 1500, more: false}, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto) - var originalPayload buffer.VectorisedView - originalPayload.AppendView(pkt.TransportHeader().View()) - originalPayload.Append(pkt.Data) - var reassembledPayload buffer.VectorisedView - pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve) - for i := 0; ; i++ { - fragPkt, offset, copied, more := pf.BuildNextFragment() - wantFragment := test.wantFragments[i] - if got := pf.RemainingFragmentCount(); got != wantFragment.remaining { - t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining) - } - if copied != wantFragment.copied { - t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied) - } - if offset != wantFragment.offset { - t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset) - } - if more != wantFragment.more { - t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more) - } - if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen { - t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen) - } - if got := fragPkt.AvailableHeaderBytes(); got != reserve { - t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve) - } - if got := fragPkt.TransportHeader().View().Size(); got != 0 { - t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got) - } - 
reassembledPayload.Append(fragPkt.Data) - if !more { - if i != len(test.wantFragments)-1 { - t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1) - } - break - } - } - if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload.ToView()); diff != "" { - t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) - } - }) - } -} - -type testTimeoutHandler struct { - pkt *stack.PacketBuffer -} - -func (h *testTimeoutHandler) OnReassemblyTimeout(pkt *stack.PacketBuffer) { - h.pkt = pkt -} - -func TestTimeoutHandler(t *testing.T) { - const ( - proto = 99 - ) - - pk1 := pkt(1, "1") - pk2 := pkt(1, "2") - - type processParam struct { - first uint16 - last uint16 - more bool - pkt *stack.PacketBuffer - } - - tests := []struct { - name string - params []processParam - wantError bool - wantPkt *stack.PacketBuffer - }{ - { - name: "onTimeout runs", - params: []processParam{ - { - first: 0, - last: 0, - more: true, - pkt: pk1, - }, - }, - wantError: false, - wantPkt: pk1, - }, - { - name: "no first fragment", - params: []processParam{ - { - first: 1, - last: 1, - more: true, - pkt: pk1, - }, - }, - wantError: false, - wantPkt: nil, - }, - { - name: "second pkt is ignored", - params: []processParam{ - { - first: 0, - last: 0, - more: true, - pkt: pk1, - }, - { - first: 0, - last: 0, - more: true, - pkt: pk2, - }, - }, - wantError: false, - wantPkt: pk1, - }, - { - name: "invalid args - first is greater than last", - params: []processParam{ - { - first: 1, - last: 0, - more: true, - pkt: pk1, - }, - }, - wantError: true, - wantPkt: nil, - }, - } - - id := FragmentID{ID: 0} - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - handler := &testTimeoutHandler{pkt: nil} - - f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, handler) - - for _, p := range test.params { - if _, _, _, err := f.Process(id, p.first, p.last, p.more, proto, p.pkt); err != nil && 
!test.wantError { - t.Errorf("f.Process error = %s", err) - } - } - if !test.wantError { - r, ok := f.reassemblers[id] - if !ok { - t.Fatal("Reassembler not found") - } - f.release(r, true) - } - switch { - case handler.pkt != nil && test.wantPkt == nil: - t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data.ToView()) - case handler.pkt == nil && test.wantPkt != nil: - t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data.ToView()) - case handler.pkt != nil && test.wantPkt != nil: - if diff := cmp.Diff(test.wantPkt.Data.ToView(), handler.pkt.Data.ToView()); diff != "" { - t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff) - } - } - }) - } -} diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go deleted file mode 100644 index 933d63d32..000000000 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "math" - "sort" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type hole struct { - first uint16 - last uint16 - filled bool - final bool - // pkt is the fragment packet if hole is filled. We keep the whole pkt rather - // than the fragmented payload to prevent binding to specific buffer types. 
- pkt *stack.PacketBuffer -} - -type reassembler struct { - reassemblerEntry - id FragmentID - memSize int - proto uint8 - mu sync.Mutex - holes []hole - filled int - done bool - creationTime int64 - pkt *stack.PacketBuffer -} - -func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { - r := &reassembler{ - id: id, - creationTime: clock.NowMonotonic(), - } - r.holes = append(r.holes, hole{ - first: 0, - last: math.MaxUint16, - filled: false, - final: true, - }) - return r -} - -func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (*stack.PacketBuffer, uint8, bool, int, error) { - r.mu.Lock() - defer r.mu.Unlock() - if r.done { - // A concurrent goroutine might have already reassembled - // the packet and emptied the heap while this goroutine - // was waiting on the mutex. We don't have to do anything in this case. - return nil, 0, false, 0, nil - } - - var holeFound bool - var memConsumed int - for i := range r.holes { - currentHole := &r.holes[i] - - if last < currentHole.first || currentHole.last < first { - continue - } - // For IPv6, overlaps with an existing fragment are explicitly forbidden by - // RFC 8200 section 4.5: - // If any of the fragments being reassembled overlap with any other - // fragments being reassembled for the same packet, reassembly of that - // packet must be abandoned and all the fragments that have been received - // for that packet must be discarded, and no ICMP error messages should be - // sent. - // - // It is not explicitly forbidden for IPv4, but to keep parity with Linux we - // disallow it as well: - // https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349 - if first < currentHole.first || currentHole.last < last { - // Incoming fragment only partially fits in the free hole. 
- return nil, 0, false, 0, ErrFragmentOverlap - } - if !more { - if !currentHole.final || currentHole.filled && currentHole.last != last { - // We have another final fragment, which does not perfectly overlap. - return nil, 0, false, 0, ErrFragmentConflict - } - } - - holeFound = true - if currentHole.filled { - // Incoming fragment is a duplicate. - continue - } - - // We are populating the current hole with the payload and creating a new - // hole for any unfilled ranges on either end. - if first > currentHole.first { - r.holes = append(r.holes, hole{ - first: currentHole.first, - last: first - 1, - filled: false, - final: false, - }) - } - if last < currentHole.last && more { - r.holes = append(r.holes, hole{ - first: last + 1, - last: currentHole.last, - filled: false, - final: currentHole.final, - }) - currentHole.final = false - } - memConsumed = pkt.MemSize() - r.memSize += memConsumed - // Update the current hole to precisely match the incoming fragment. - r.holes[i] = hole{ - first: first, - last: last, - filled: true, - final: currentHole.final, - pkt: pkt, - } - r.filled++ - // For IPv6, it is possible to have different Protocol values between - // fragments of a packet (because, unlike IPv4, the Protocol is not used to - // identify a fragment). In this case, only the Protocol of the first - // fragment must be used as per RFC 8200 Section 4.5. - // - // TODO(gvisor.dev/issue/3648): During reassembly of an IPv6 packet, IP - // options received in the first fragment should be used - and they should - // override options from following fragments. - if first == 0 { - r.pkt = pkt - r.proto = proto - } - - break - } - if !holeFound { - // Incoming fragment is beyond end. - return nil, 0, false, 0, ErrFragmentConflict - } - - // Check if all the holes have been filled and we are ready to reassemble. 
- if r.filled < len(r.holes) { - return nil, 0, false, memConsumed, nil - } - - sort.Slice(r.holes, func(i, j int) bool { - return r.holes[i].first < r.holes[j].first - }) - - resPkt := r.holes[0].pkt - for i := 1; i < len(r.holes); i++ { - fragPkt := r.holes[i].pkt - fragPkt.Data.ReadToVV(&resPkt.Data, fragPkt.Data.Size()) - } - return resPkt, r.proto, true, memConsumed, nil -} - -func (r *reassembler) checkDoneOrMark() bool { - r.mu.Lock() - prev := r.done - r.done = true - r.mu.Unlock() - return prev -} diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go deleted file mode 100644 index 214a93709..000000000 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ /dev/null @@ -1,233 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package fragmentation - -import ( - "bytes" - "math" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type processParams struct { - first uint16 - last uint16 - more bool - pkt *stack.PacketBuffer - wantDone bool - wantError error -} - -func TestReassemblerProcess(t *testing.T) { - const proto = 99 - - v := func(size int) buffer.View { - payload := buffer.NewView(size) - for i := 1; i < size; i++ { - payload[i] = uint8(i) * 3 - } - return payload - } - - pkt := func(sizes ...int) *stack.PacketBuffer { - var vv buffer.VectorisedView - for _, size := range sizes { - vv.AppendView(v(size)) - } - return stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - } - - var tests = []struct { - name string - params []processParams - want []hole - wantPkt *stack.PacketBuffer - }{ - { - name: "No fragments", - params: nil, - want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}}, - }, - { - name: "One fragment at beginning", - params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, - want: []hole{ - {first: 0, last: 1, filled: true, final: false, pkt: pkt(2)}, - {first: 2, last: math.MaxUint16, filled: false, final: true}, - }, - }, - { - name: "One fragment in the middle", - params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, - want: []hole{ - {first: 1, last: 2, filled: true, final: false, pkt: pkt(2)}, - {first: 0, last: 0, filled: false, final: false}, - {first: 3, last: math.MaxUint16, filled: false, final: true}, - }, - }, - { - name: "One fragment at the end", - params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}}, - want: []hole{ - {first: 1, last: 2, filled: true, final: true, pkt: pkt(2)}, - {first: 0, last: 0, filled: false}, - }, - }, - { - name: "One fragment 
completing a packet", - params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}}, - want: []hole{ - {first: 0, last: 1, filled: true, final: true}, - }, - wantPkt: pkt(2), - }, - { - name: "Two fragments completing a packet", - params: []processParams{ - {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, - }, - want: []hole{ - {first: 0, last: 1, filled: true, final: false}, - {first: 2, last: 3, filled: true, final: true}, - }, - wantPkt: pkt(2, 2), - }, - { - name: "Two fragments completing a packet with a duplicate", - params: []processParams{ - {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, - }, - want: []hole{ - {first: 0, last: 1, filled: true, final: false}, - {first: 2, last: 3, filled: true, final: true}, - }, - wantPkt: pkt(2, 2), - }, - { - name: "Two fragments completing a packet with a partial duplicate", - params: []processParams{ - {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil}, - {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, - }, - want: []hole{ - {first: 0, last: 3, filled: true, final: false}, - {first: 4, last: 5, filled: true, final: true}, - }, - wantPkt: pkt(4, 2), - }, - { - name: "Two overlapping fragments", - params: []processParams{ - {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil}, - {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap}, - }, - want: []hole{ - {first: 0, last: 10, filled: true, final: false, pkt: pkt(11)}, - {first: 11, last: math.MaxUint16, filled: false, final: 
true}, - }, - }, - { - name: "Two final fragments with different ends", - params: []processParams{ - {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, - {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict}, - }, - want: []hole{ - {first: 10, last: 14, filled: true, final: true, pkt: pkt(5)}, - {first: 0, last: 9, filled: false, final: false}, - }, - }, - { - name: "Two final fragments - duplicate", - params: []processParams{ - {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, - {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, - }, - want: []hole{ - {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)}, - {first: 0, last: 4, filled: false, final: false}, - }, - }, - { - name: "Two final fragments - duplicate, with different ends", - params: []processParams{ - {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, - {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict}, - }, - want: []hole{ - {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)}, - {first: 0, last: 4, filled: false, final: false}, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - r := newReassembler(FragmentID{}, &faketime.NullClock{}) - var resPkt *stack.PacketBuffer - var isDone bool - for _, param := range test.params { - pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) - if done != param.wantDone || err != param.wantError { - t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError) - } - if done { - resPkt = pkt - isDone = true - } - } - - ignorePkt := func(a, b *stack.PacketBuffer) bool { return true } - cmpPktData := func(a, b *stack.PacketBuffer) bool { - if a == nil || b == nil { - 
return a == b - } - return bytes.Equal(a.Data.ToOwnedView(), b.Data.ToOwnedView()) - } - - if isDone { - if diff := cmp.Diff( - test.want, r.holes, - cmp.AllowUnexported(hole{}), - // Do not compare pkt in hole. Data will be altered. - cmp.Comparer(ignorePkt), - ); diff != "" { - t.Errorf("r.holes mismatch (-want +got):\n%s", diff) - } - if diff := cmp.Diff(test.wantPkt, resPkt, cmp.Comparer(cmpPktData)); diff != "" { - t.Errorf("Reassembled pkt mismatch (-want +got):\n%s", diff) - } - } else { - if diff := cmp.Diff( - test.want, r.holes, - cmp.AllowUnexported(hole{}), - cmp.Comparer(cmpPktData), - ); diff != "" { - t.Errorf("r.holes mismatch (-want +got):\n%s", diff) - } - } - }) - } -} diff --git a/pkg/tcpip/network/internal/fragmentation/BUILD b/pkg/tcpip/network/internal/fragmentation/BUILD new file mode 100644 index 000000000..274f09092 --- /dev/null +++ b/pkg/tcpip/network/internal/fragmentation/BUILD @@ -0,0 +1,54 @@ +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "reassembler_list", + out = "reassembler_list.go", + package = "fragmentation", + prefix = "reassembler", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*reassembler", + "Linker": "*reassembler", + }, +) + +go_library( + name = "fragmentation", + srcs = [ + "fragmentation.go", + "reassembler.go", + "reassembler_list.go", + ], + visibility = [ + "//pkg/tcpip/network/ipv4:__pkg__", + "//pkg/tcpip/network/ipv6:__pkg__", + ], + deps = [ + "//pkg/log", + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) + +go_test( + name = "fragmentation_test", + size = "small", + srcs = [ + "fragmentation_test.go", + "reassembler_test.go", + ], + library = ":fragmentation", + deps = [ + "//pkg/tcpip/buffer", + "//pkg/tcpip/faketime", + "//pkg/tcpip/network/internal/testutil", + "//pkg/tcpip/stack", + 
"@com_github_google_go_cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation.go b/pkg/tcpip/network/internal/fragmentation/fragmentation.go new file mode 100644 index 000000000..243738951 --- /dev/null +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation.go @@ -0,0 +1,339 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package fragmentation contains the implementation of IP fragmentation. +// It is based on RFC 791, RFC 815 and RFC 8200. +package fragmentation + +import ( + "errors" + "fmt" + "log" + "time" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + // HighFragThreshold is the threshold at which we start trimming old + // fragmented packets. Linux uses a default value of 4 MB. See + // net.ipv4.ipfrag_high_thresh for more information. + HighFragThreshold = 4 << 20 // 4MB + + // LowFragThreshold is the threshold we reach to when we start dropping + // older fragmented packets. It's important that we keep enough room for newer + // packets to be re-assembled. Hence, this needs to be lower than + // HighFragThreshold enough. Linux uses a default value of 3 MB. See + // net.ipv4.ipfrag_low_thresh for more information. + LowFragThreshold = 3 << 20 // 3MB + + // minBlockSize is the minimum block size for fragments. 
+ minBlockSize = 1 +) + +var ( + // ErrInvalidArgs indicates to the caller that an invalid argument was + // provided. + ErrInvalidArgs = errors.New("invalid args") + + // ErrFragmentOverlap indicates that, during reassembly, a fragment overlaps + // with another one. + ErrFragmentOverlap = errors.New("overlapping fragments") + + // ErrFragmentConflict indicates that, during reassembly, some fragments are + // in conflict with one another. + ErrFragmentConflict = errors.New("conflicting fragments") +) + +// FragmentID is the identifier for a fragment. +type FragmentID struct { + // Source is the source address of the fragment. + Source tcpip.Address + + // Destination is the destination address of the fragment. + Destination tcpip.Address + + // ID is the identification value of the fragment. + // + // This is a uint32 because IPv6 uses a 32-bit identification value. + ID uint32 + + // The protocol for the packet. + Protocol uint8 +} + +// Fragmentation is the main structure that other modules +// of the stack should use to implement IP Fragmentation. +type Fragmentation struct { + mu sync.Mutex + highLimit int + lowLimit int + reassemblers map[FragmentID]*reassembler + rList reassemblerList + memSize int + timeout time.Duration + blockSize uint16 + clock tcpip.Clock + releaseJob *tcpip.Job + timeoutHandler TimeoutHandler +} + +// TimeoutHandler is consulted if a packet reassembly has timed out. +type TimeoutHandler interface { + // OnReassemblyTimeout will be called with the first fragment (or nil, if the + // first fragment has not been received) of a packet whose reassembly has + // timed out. + OnReassemblyTimeout(pkt *stack.PacketBuffer) +} + +// NewFragmentation creates a new Fragmentation. +// +// blockSize specifies the fragment block size, in bytes. +// +// highMemoryLimit specifies the limit on the memory consumed +// by the fragments stored by Fragmentation (overhead of internal data-structures +// is not accounted). 
Fragments are dropped when the limit is reached. +// +// lowMemoryLimit specifies the limit on which we will reach by dropping +// fragments after reaching highMemoryLimit. +// +// reassemblingTimeout specifies the maximum time allowed to reassemble a packet. +// Fragments are lazily evicted only when a new a packet with an +// already existing fragmentation-id arrives after the timeout. +func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock, timeoutHandler TimeoutHandler) *Fragmentation { + if lowMemoryLimit >= highMemoryLimit { + lowMemoryLimit = highMemoryLimit + } + + if lowMemoryLimit < 0 { + lowMemoryLimit = 0 + } + + if blockSize < minBlockSize { + blockSize = minBlockSize + } + + f := &Fragmentation{ + reassemblers: make(map[FragmentID]*reassembler), + highLimit: highMemoryLimit, + lowLimit: lowMemoryLimit, + timeout: reassemblingTimeout, + blockSize: blockSize, + clock: clock, + timeoutHandler: timeoutHandler, + } + f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked) + + return f +} + +// Process processes an incoming fragment belonging to an ID and returns a +// complete packet and its protocol number when all the packets belonging to +// that ID have been received. +// +// [first, last] is the range of the fragment bytes. +// +// first must be a multiple of the block size f is configured with. The size +// of the fragment data must be a multiple of the block size, unless there are +// no fragments following this fragment (more set to false). +// +// proto is the protocol number marked in the fragment being processed. It has +// to be given here outside of the FragmentID struct because IPv6 should not use +// the protocol to identify a fragment. 
+func (f *Fragmentation) Process( + id FragmentID, first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) ( + *stack.PacketBuffer, uint8, bool, error) { + if first > last { + return nil, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) + } + + if first%f.blockSize != 0 { + return nil, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs) + } + + fragmentSize := last - first + 1 + if more && fragmentSize%f.blockSize != 0 { + return nil, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) + } + + if l := pkt.Data.Size(); l != int(fragmentSize) { + return nil, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) + } + + f.mu.Lock() + r, ok := f.reassemblers[id] + if !ok { + r = newReassembler(id, f.clock) + f.reassemblers[id] = r + wasEmpty := f.rList.Empty() + f.rList.PushFront(r) + if wasEmpty { + // If we have just pushed a first reassembler into an empty list, we + // should kickstart the release job. The release job will keep + // rescheduling itself until the list becomes empty. + f.releaseReassemblersLocked() + } + } + f.mu.Unlock() + + resPkt, firstFragmentProto, done, memConsumed, err := r.process(first, last, more, proto, pkt) + if err != nil { + // We probably got an invalid sequence of fragments. Just + // discard the reassembler and move on. + f.mu.Lock() + f.release(r, false /* timedOut */) + f.mu.Unlock() + return nil, 0, false, fmt.Errorf("fragmentation processing error: %w", err) + } + f.mu.Lock() + f.memSize += memConsumed + if done { + f.release(r, false /* timedOut */) + } + // Evict reassemblers if we are consuming more memory than highLimit until + // we reach lowLimit. 
+ if f.memSize > f.highLimit { + for f.memSize > f.lowLimit { + tail := f.rList.Back() + if tail == nil { + break + } + f.release(tail, false /* timedOut */) + } + } + f.mu.Unlock() + return resPkt, firstFragmentProto, done, nil +} + +func (f *Fragmentation) release(r *reassembler, timedOut bool) { + // Before releasing a fragment we need to check if r is already marked as done. + // Otherwise, we would delete it twice. + if r.checkDoneOrMark() { + return + } + + delete(f.reassemblers, r.id) + f.rList.Remove(r) + f.memSize -= r.memSize + if f.memSize < 0 { + log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.memSize) + f.memSize = 0 + } + + if h := f.timeoutHandler; timedOut && h != nil { + h.OnReassemblyTimeout(r.pkt) + } +} + +// releaseReassemblersLocked releases already-expired reassemblers, then +// schedules the job to call back itself for the remaining reassemblers if +// any. This function must be called with f.mu locked. +func (f *Fragmentation) releaseReassemblersLocked() { + now := f.clock.NowMonotonic() + for { + // The reassembler at the end of the list is the oldest. + r := f.rList.Back() + if r == nil { + // The list is empty. + break + } + elapsed := time.Duration(now-r.creationTime) * time.Nanosecond + if f.timeout > elapsed { + // If the oldest reassembler has not expired, schedule the release + // job so that this function is called back when it has expired. + f.releaseJob.Schedule(f.timeout - elapsed) + break + } + // If the oldest reassembler has already expired, release it. + f.release(r, true /* timedOut*/) + } +} + +// PacketFragmenter is the book-keeping struct for packet fragmentation. +type PacketFragmenter struct { + transportHeader buffer.View + data buffer.VectorisedView + reserve int + fragmentPayloadLen int + fragmentCount int + currentFragment int + fragmentOffset int +} + +// MakePacketFragmenter prepares the struct needed for packet fragmentation. 
+// +// pkt is the packet to be fragmented. +// +// fragmentPayloadLen is the maximum number of bytes of fragmentable data a fragment can +// have. +// +// reserve is the number of bytes that should be reserved for the headers in +// each generated fragment. +func MakePacketFragmenter(pkt *stack.PacketBuffer, fragmentPayloadLen uint32, reserve int) PacketFragmenter { + // As per RFC 8200 Section 4.5, some IPv6 extension headers should not be + // repeated in each fragment. However we do not currently support any header + // of that kind yet, so the following computation is valid for both IPv4 and + // IPv6. + // TODO(gvisor.dev/issue/3912): Once Authentication or ESP Headers are + // supported for outbound packets, the fragmentable data should not include + // these headers. + var fragmentableData buffer.VectorisedView + fragmentableData.AppendView(pkt.TransportHeader().View()) + fragmentableData.Append(pkt.Data) + fragmentCount := (uint32(fragmentableData.Size()) + fragmentPayloadLen - 1) / fragmentPayloadLen + + return PacketFragmenter{ + data: fragmentableData, + reserve: reserve, + fragmentPayloadLen: int(fragmentPayloadLen), + fragmentCount: int(fragmentCount), + } +} + +// BuildNextFragment returns a packet with the payload of the next fragment, +// along with the fragment's offset, the number of bytes copied and a boolean +// indicating if there are more fragments left or not. If this function is +// called again after it indicated that no more fragments were left, it will +// panic. +// +// Note that the returned packet will not have its network and link headers +// populated, but space for them will be reserved. The transport header will be +// stored in the packet's data. 
+func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, bool) { + if pf.currentFragment >= pf.fragmentCount { + panic("BuildNextFragment should not be called again after the last fragment was returned") + } + + fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: pf.reserve, + }) + + // Copy data for the fragment. + copied := pf.data.ReadToVV(&fragPkt.Data, pf.fragmentPayloadLen) + + offset := pf.fragmentOffset + pf.fragmentOffset += copied + pf.currentFragment++ + more := pf.currentFragment != pf.fragmentCount + + return fragPkt, offset, copied, more +} + +// RemainingFragmentCount returns the number of fragments left to be built. +func (pf *PacketFragmenter) RemainingFragmentCount() int { + return pf.fragmentCount - pf.currentFragment +} diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go new file mode 100644 index 000000000..47ea3173e --- /dev/null +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go @@ -0,0 +1,638 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package fragmentation + +import ( + "errors" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// reassembleTimeout is dummy timeout used for testing, where the clock never +// advances. +const reassembleTimeout = 1 + +// vv is a helper to build VectorisedView from different strings. +func vv(size int, pieces ...string) buffer.VectorisedView { + views := make([]buffer.View, len(pieces)) + for i, p := range pieces { + views[i] = []byte(p) + } + + return buffer.NewVectorisedView(size, views) +} + +func pkt(size int, pieces ...string) *stack.PacketBuffer { + return stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv(size, pieces...), + }) +} + +type processInput struct { + id FragmentID + first uint16 + last uint16 + more bool + proto uint8 + pkt *stack.PacketBuffer +} + +type processOutput struct { + vv buffer.VectorisedView + proto uint8 + done bool +} + +var processTestCases = []struct { + comment string + in []processInput + out []processOutput +}{ + { + comment: "One ID", + in: []processInput{ + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, + }, + out: []processOutput{ + {vv: buffer.VectorisedView{}, done: false}, + {vv: vv(4, "01", "23"), done: true}, + }, + }, + { + comment: "Next Header protocol mismatch", + in: []processInput{ + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, pkt: pkt(2, "01")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, pkt: pkt(2, "23")}, + }, + out: []processOutput{ + {vv: buffer.VectorisedView{}, done: false}, + {vv: vv(4, "01", "23"), proto: 6, done: true}, + }, + }, + { + comment: "Two IDs", + in: []processInput{ + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, + 
{id: FragmentID{ID: 1}, first: 0, last: 1, more: true, pkt: pkt(2, "ab")}, + {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, pkt: pkt(2, "cd")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, + }, + out: []processOutput{ + {vv: buffer.VectorisedView{}, done: false}, + {vv: buffer.VectorisedView{}, done: false}, + {vv: vv(4, "ab", "cd"), done: true}, + {vv: vv(4, "01", "23"), done: true}, + }, + }, +} + +func TestFragmentationProcess(t *testing.T) { + for _, c := range processTestCases { + t.Run(c.comment, func(t *testing.T) { + f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil) + firstFragmentProto := c.in[0].proto + for i, in := range c.in { + resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) + if err != nil { + t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s", + in.id, in.first, in.last, in.more, in.proto, in.pkt, err) + } + if done != c.out[i].done { + t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)", + in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done) + } + if c.out[i].done { + if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data.ToOwnedView()); diff != "" { + t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s", + in.id, in.first, in.last, in.more, in.proto, in.pkt, diff) + } + if firstFragmentProto != proto { + t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)", + in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto) + } + if _, ok := f.reassemblers[in.id]; ok { + t.Errorf("Process(%d) did not remove buffer from reassemblers", i) + } + for n := f.rList.Front(); n != nil; n = n.Next() { + if n.id == in.id { + t.Errorf("Process(%d) did not remove buffer from rList", i) + } + } + } + } + }) + } +} + +func TestReassemblingTimeout(t *testing.T) { + const ( + reassemblyTimeout = time.Millisecond + 
protocol = 0xff + ) + + type fragment struct { + first uint16 + last uint16 + more bool + data string + } + + type event struct { + // name is a nickname of this event. + name string + + // clockAdvance is a duration to advance the clock. The clock advances + // before a fragment specified in the fragment field is processed. + clockAdvance time.Duration + + // fragment is a fragment to process. This can be nil if there is no + // fragment to process. + fragment *fragment + + // expectDone is true if the fragmentation instance should report the + // reassembly is done after the fragment is processed. + expectDone bool + + // memSizeAfterEvent is the expected memory size of the fragmentation + // instance after the event. + memSizeAfterEvent int + } + + memSizeOfFrags := func(frags ...*fragment) int { + var size int + for _, frag := range frags { + size += pkt(len(frag.data), frag.data).MemSize() + } + return size + } + + half1 := &fragment{first: 0, last: 0, more: true, data: "0"} + half2 := &fragment{first: 1, last: 1, more: false, data: "1"} + + tests := []struct { + name string + events []event + }{ + { + name: "half1 and half2 are reassembled successfully", + events: []event{ + { + name: "half1", + fragment: half1, + expectDone: false, + memSizeAfterEvent: memSizeOfFrags(half1), + }, + { + name: "half2", + fragment: half2, + expectDone: true, + memSizeAfterEvent: 0, + }, + }, + }, + { + name: "half1 timeout, half2 timeout", + events: []event{ + { + name: "half1", + fragment: half1, + expectDone: false, + memSizeAfterEvent: memSizeOfFrags(half1), + }, + { + name: "half1 just before reassembly timeout", + clockAdvance: reassemblyTimeout - 1, + memSizeAfterEvent: memSizeOfFrags(half1), + }, + { + name: "half1 reassembly timeout", + clockAdvance: 1, + memSizeAfterEvent: 0, + }, + { + name: "half2", + fragment: half2, + expectDone: false, + memSizeAfterEvent: memSizeOfFrags(half2), + }, + { + name: "half2 just before reassembly timeout", + clockAdvance: 
reassemblyTimeout - 1, + memSizeAfterEvent: memSizeOfFrags(half2), + }, + { + name: "half2 reassembly timeout", + clockAdvance: 1, + memSizeAfterEvent: 0, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock, nil) + for _, event := range test.events { + clock.Advance(event.clockAdvance) + if frag := event.fragment; frag != nil { + _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, pkt(len(frag.data), frag.data)) + if err != nil { + t.Fatalf("%s: f.Process failed: %s", event.name, err) + } + if done != event.expectDone { + t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone) + } + } + if got, want := f.memSize, event.memSizeAfterEvent; got != want { + t.Errorf("%s: got f.memSize = %d, want = %d", event.name, got, want) + } + } + }) + } +} + +func TestMemoryLimits(t *testing.T) { + lowLimit := pkt(1, "0").MemSize() + highLimit := 3 * lowLimit // Allow at most 3 such packets. + f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil) + // Send first fragment with id = 0. + f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")) + // Send first fragment with id = 1. + f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")) + // Send first fragment with id = 2. + f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")) + + // Send first fragment with id = 3. This should cause id = 0 and id = 1 to be + // evicted. 
+ f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")) + + if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { + t.Errorf("Memory limits are not respected: id=0 has not been evicted.") + } + if _, ok := f.reassemblers[FragmentID{ID: 1}]; ok { + t.Errorf("Memory limits are not respected: id=1 has not been evicted.") + } + if _, ok := f.reassemblers[FragmentID{ID: 3}]; !ok { + t.Errorf("Implementation of memory limits is wrong: id=3 is not present.") + } +} + +func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { + memSize := pkt(1, "0").MemSize() + f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil) + // Send first fragment with id = 0. + f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) + // Send the same packet again. + f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) + + if got, want := f.memSize, memSize; got != want { + t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) + } +} + +func TestErrors(t *testing.T) { + tests := []struct { + name string + blockSize uint16 + first uint16 + last uint16 + more bool + data string + err error + }{ + { + name: "exact block size without more", + blockSize: 2, + first: 2, + last: 3, + more: false, + data: "01", + }, + { + name: "exact block size with more", + blockSize: 2, + first: 2, + last: 3, + more: true, + data: "01", + }, + { + name: "exact block size with more and extra data", + blockSize: 2, + first: 2, + last: 3, + more: true, + data: "012", + err: ErrInvalidArgs, + }, + { + name: "exact block size with more and too little data", + blockSize: 2, + first: 2, + last: 3, + more: true, + data: "0", + err: ErrInvalidArgs, + }, + { + name: "not exact block size with more", + blockSize: 2, + first: 2, + last: 2, + more: true, + data: "0", + err: ErrInvalidArgs, + }, + { + name: "not exact block size without more", + blockSize: 2, + first: 2, + last: 2, + more: false, + data: "0", + }, + { + name: "first not a multiple of 
block size", + blockSize: 2, + first: 3, + last: 4, + more: true, + data: "01", + err: ErrInvalidArgs, + }, + { + name: "first more than last", + blockSize: 2, + first: 4, + last: 3, + more: true, + data: "01", + err: ErrInvalidArgs, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, nil) + _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, pkt(len(test.data), test.data)) + if !errors.Is(err, test.err) { + t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) + } + if done { + t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, true, _), want = (_, _, false, _)", test.first, test.last, test.more, test.data) + } + }) + } +} + +type fragmentInfo struct { + remaining int + copied int + offset int + more bool +} + +func TestPacketFragmenter(t *testing.T) { + const ( + reserve = 60 + proto = 0 + ) + + tests := []struct { + name string + fragmentPayloadLen uint32 + transportHeaderLen int + payloadSize int + wantFragments []fragmentInfo + }{ + { + name: "Packet exactly fits in MTU", + fragmentPayloadLen: 1280, + transportHeaderLen: 0, + payloadSize: 1280, + wantFragments: []fragmentInfo{ + {remaining: 0, copied: 1280, offset: 0, more: false}, + }, + }, + { + name: "Packet exactly does not fit in MTU", + fragmentPayloadLen: 1000, + transportHeaderLen: 0, + payloadSize: 1001, + wantFragments: []fragmentInfo{ + {remaining: 1, copied: 1000, offset: 0, more: true}, + {remaining: 0, copied: 1, offset: 1000, more: false}, + }, + }, + { + name: "Packet has a transport header", + fragmentPayloadLen: 560, + transportHeaderLen: 40, + payloadSize: 560, + wantFragments: []fragmentInfo{ + {remaining: 1, copied: 560, offset: 0, more: true}, + {remaining: 0, copied: 40, offset: 560, more: false}, + }, + }, + { + name: 
"Packet has a huge transport header", + fragmentPayloadLen: 500, + transportHeaderLen: 1300, + payloadSize: 500, + wantFragments: []fragmentInfo{ + {remaining: 3, copied: 500, offset: 0, more: true}, + {remaining: 2, copied: 500, offset: 500, more: true}, + {remaining: 1, copied: 500, offset: 1000, more: true}, + {remaining: 0, copied: 300, offset: 1500, more: false}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto) + var originalPayload buffer.VectorisedView + originalPayload.AppendView(pkt.TransportHeader().View()) + originalPayload.Append(pkt.Data) + var reassembledPayload buffer.VectorisedView + pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve) + for i := 0; ; i++ { + fragPkt, offset, copied, more := pf.BuildNextFragment() + wantFragment := test.wantFragments[i] + if got := pf.RemainingFragmentCount(); got != wantFragment.remaining { + t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining) + } + if copied != wantFragment.copied { + t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied) + } + if offset != wantFragment.offset { + t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset) + } + if more != wantFragment.more { + t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more) + } + if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen { + t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen) + } + if got := fragPkt.AvailableHeaderBytes(); got != reserve { + t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve) + } + if got := fragPkt.TransportHeader().View().Size(); got != 0 { + t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got) + } + 
reassembledPayload.Append(fragPkt.Data) + if !more { + if i != len(test.wantFragments)-1 { + t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1) + } + break + } + } + if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload.ToView()); diff != "" { + t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) + } + }) + } +} + +type testTimeoutHandler struct { + pkt *stack.PacketBuffer +} + +func (h *testTimeoutHandler) OnReassemblyTimeout(pkt *stack.PacketBuffer) { + h.pkt = pkt +} + +func TestTimeoutHandler(t *testing.T) { + const ( + proto = 99 + ) + + pk1 := pkt(1, "1") + pk2 := pkt(1, "2") + + type processParam struct { + first uint16 + last uint16 + more bool + pkt *stack.PacketBuffer + } + + tests := []struct { + name string + params []processParam + wantError bool + wantPkt *stack.PacketBuffer + }{ + { + name: "onTimeout runs", + params: []processParam{ + { + first: 0, + last: 0, + more: true, + pkt: pk1, + }, + }, + wantError: false, + wantPkt: pk1, + }, + { + name: "no first fragment", + params: []processParam{ + { + first: 1, + last: 1, + more: true, + pkt: pk1, + }, + }, + wantError: false, + wantPkt: nil, + }, + { + name: "second pkt is ignored", + params: []processParam{ + { + first: 0, + last: 0, + more: true, + pkt: pk1, + }, + { + first: 0, + last: 0, + more: true, + pkt: pk2, + }, + }, + wantError: false, + wantPkt: pk1, + }, + { + name: "invalid args - first is greater than last", + params: []processParam{ + { + first: 1, + last: 0, + more: true, + pkt: pk1, + }, + }, + wantError: true, + wantPkt: nil, + }, + } + + id := FragmentID{ID: 0} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + handler := &testTimeoutHandler{pkt: nil} + + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, handler) + + for _, p := range test.params { + if _, _, _, err := f.Process(id, p.first, p.last, p.more, proto, p.pkt); err != nil && 
!test.wantError { + t.Errorf("f.Process error = %s", err) + } + } + if !test.wantError { + r, ok := f.reassemblers[id] + if !ok { + t.Fatal("Reassembler not found") + } + f.release(r, true) + } + switch { + case handler.pkt != nil && test.wantPkt == nil: + t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data.ToView()) + case handler.pkt == nil && test.wantPkt != nil: + t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data.ToView()) + case handler.pkt != nil && test.wantPkt != nil: + if diff := cmp.Diff(test.wantPkt.Data.ToView(), handler.pkt.Data.ToView()); diff != "" { + t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff) + } + } + }) + } +} diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go new file mode 100644 index 000000000..933d63d32 --- /dev/null +++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go @@ -0,0 +1,182 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fragmentation + +import ( + "math" + "sort" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type hole struct { + first uint16 + last uint16 + filled bool + final bool + // pkt is the fragment packet if hole is filled. We keep the whole pkt rather + // than the fragmented payload to prevent binding to specific buffer types. 
+ pkt *stack.PacketBuffer +} + +type reassembler struct { + reassemblerEntry + id FragmentID + memSize int + proto uint8 + mu sync.Mutex + holes []hole + filled int + done bool + creationTime int64 + pkt *stack.PacketBuffer +} + +func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { + r := &reassembler{ + id: id, + creationTime: clock.NowMonotonic(), + } + r.holes = append(r.holes, hole{ + first: 0, + last: math.MaxUint16, + filled: false, + final: true, + }) + return r +} + +func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (*stack.PacketBuffer, uint8, bool, int, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.done { + // A concurrent goroutine might have already reassembled + // the packet and emptied the heap while this goroutine + // was waiting on the mutex. We don't have to do anything in this case. + return nil, 0, false, 0, nil + } + + var holeFound bool + var memConsumed int + for i := range r.holes { + currentHole := &r.holes[i] + + if last < currentHole.first || currentHole.last < first { + continue + } + // For IPv6, overlaps with an existing fragment are explicitly forbidden by + // RFC 8200 section 4.5: + // If any of the fragments being reassembled overlap with any other + // fragments being reassembled for the same packet, reassembly of that + // packet must be abandoned and all the fragments that have been received + // for that packet must be discarded, and no ICMP error messages should be + // sent. + // + // It is not explicitly forbidden for IPv4, but to keep parity with Linux we + // disallow it as well: + // https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349 + if first < currentHole.first || currentHole.last < last { + // Incoming fragment only partially fits in the free hole. 
+ return nil, 0, false, 0, ErrFragmentOverlap + } + if !more { + if !currentHole.final || currentHole.filled && currentHole.last != last { + // We have another final fragment, which does not perfectly overlap. + return nil, 0, false, 0, ErrFragmentConflict + } + } + + holeFound = true + if currentHole.filled { + // Incoming fragment is a duplicate. + continue + } + + // We are populating the current hole with the payload and creating a new + // hole for any unfilled ranges on either end. + if first > currentHole.first { + r.holes = append(r.holes, hole{ + first: currentHole.first, + last: first - 1, + filled: false, + final: false, + }) + } + if last < currentHole.last && more { + r.holes = append(r.holes, hole{ + first: last + 1, + last: currentHole.last, + filled: false, + final: currentHole.final, + }) + currentHole.final = false + } + memConsumed = pkt.MemSize() + r.memSize += memConsumed + // Update the current hole to precisely match the incoming fragment. + r.holes[i] = hole{ + first: first, + last: last, + filled: true, + final: currentHole.final, + pkt: pkt, + } + r.filled++ + // For IPv6, it is possible to have different Protocol values between + // fragments of a packet (because, unlike IPv4, the Protocol is not used to + // identify a fragment). In this case, only the Protocol of the first + // fragment must be used as per RFC 8200 Section 4.5. + // + // TODO(gvisor.dev/issue/3648): During reassembly of an IPv6 packet, IP + // options received in the first fragment should be used - and they should + // override options from following fragments. + if first == 0 { + r.pkt = pkt + r.proto = proto + } + + break + } + if !holeFound { + // Incoming fragment is beyond end. + return nil, 0, false, 0, ErrFragmentConflict + } + + // Check if all the holes have been filled and we are ready to reassemble. 
+ if r.filled < len(r.holes) { + return nil, 0, false, memConsumed, nil + } + + sort.Slice(r.holes, func(i, j int) bool { + return r.holes[i].first < r.holes[j].first + }) + + resPkt := r.holes[0].pkt + for i := 1; i < len(r.holes); i++ { + fragPkt := r.holes[i].pkt + fragPkt.Data.ReadToVV(&resPkt.Data, fragPkt.Data.Size()) + } + return resPkt, r.proto, true, memConsumed, nil +} + +func (r *reassembler) checkDoneOrMark() bool { + r.mu.Lock() + prev := r.done + r.done = true + r.mu.Unlock() + return prev +} diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go new file mode 100644 index 000000000..214a93709 --- /dev/null +++ b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go @@ -0,0 +1,233 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package fragmentation + +import ( + "bytes" + "math" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type processParams struct { + first uint16 + last uint16 + more bool + pkt *stack.PacketBuffer + wantDone bool + wantError error +} + +func TestReassemblerProcess(t *testing.T) { + const proto = 99 + + v := func(size int) buffer.View { + payload := buffer.NewView(size) + for i := 1; i < size; i++ { + payload[i] = uint8(i) * 3 + } + return payload + } + + pkt := func(sizes ...int) *stack.PacketBuffer { + var vv buffer.VectorisedView + for _, size := range sizes { + vv.AppendView(v(size)) + } + return stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + } + + var tests = []struct { + name string + params []processParams + want []hole + wantPkt *stack.PacketBuffer + }{ + { + name: "No fragments", + params: nil, + want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}}, + }, + { + name: "One fragment at beginning", + params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, + want: []hole{ + {first: 0, last: 1, filled: true, final: false, pkt: pkt(2)}, + {first: 2, last: math.MaxUint16, filled: false, final: true}, + }, + }, + { + name: "One fragment in the middle", + params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, + want: []hole{ + {first: 1, last: 2, filled: true, final: false, pkt: pkt(2)}, + {first: 0, last: 0, filled: false, final: false}, + {first: 3, last: math.MaxUint16, filled: false, final: true}, + }, + }, + { + name: "One fragment at the end", + params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}}, + want: []hole{ + {first: 1, last: 2, filled: true, final: true, pkt: pkt(2)}, + {first: 0, last: 0, filled: false}, + }, + }, + { + name: "One fragment 
completing a packet", + params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}}, + want: []hole{ + {first: 0, last: 1, filled: true, final: true}, + }, + wantPkt: pkt(2), + }, + { + name: "Two fragments completing a packet", + params: []processParams{ + {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, + }, + want: []hole{ + {first: 0, last: 1, filled: true, final: false}, + {first: 2, last: 3, filled: true, final: true}, + }, + wantPkt: pkt(2, 2), + }, + { + name: "Two fragments completing a packet with a duplicate", + params: []processParams{ + {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, + }, + want: []hole{ + {first: 0, last: 1, filled: true, final: false}, + {first: 2, last: 3, filled: true, final: true}, + }, + wantPkt: pkt(2, 2), + }, + { + name: "Two fragments completing a packet with a partial duplicate", + params: []processParams{ + {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil}, + {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, + }, + want: []hole{ + {first: 0, last: 3, filled: true, final: false}, + {first: 4, last: 5, filled: true, final: true}, + }, + wantPkt: pkt(4, 2), + }, + { + name: "Two overlapping fragments", + params: []processParams{ + {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil}, + {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap}, + }, + want: []hole{ + {first: 0, last: 10, filled: true, final: false, pkt: pkt(11)}, + {first: 11, last: math.MaxUint16, filled: false, final: 
true}, + }, + }, + { + name: "Two final fragments with different ends", + params: []processParams{ + {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, + {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict}, + }, + want: []hole{ + {first: 10, last: 14, filled: true, final: true, pkt: pkt(5)}, + {first: 0, last: 9, filled: false, final: false}, + }, + }, + { + name: "Two final fragments - duplicate", + params: []processParams{ + {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, + {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, + }, + want: []hole{ + {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)}, + {first: 0, last: 4, filled: false, final: false}, + }, + }, + { + name: "Two final fragments - duplicate, with different ends", + params: []processParams{ + {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, + {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict}, + }, + want: []hole{ + {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)}, + {first: 0, last: 4, filled: false, final: false}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + r := newReassembler(FragmentID{}, &faketime.NullClock{}) + var resPkt *stack.PacketBuffer + var isDone bool + for _, param := range test.params { + pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) + if done != param.wantDone || err != param.wantError { + t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError) + } + if done { + resPkt = pkt + isDone = true + } + } + + ignorePkt := func(a, b *stack.PacketBuffer) bool { return true } + cmpPktData := func(a, b *stack.PacketBuffer) bool { + if a == nil || b == nil { + 
return a == b + } + return bytes.Equal(a.Data.ToOwnedView(), b.Data.ToOwnedView()) + } + + if isDone { + if diff := cmp.Diff( + test.want, r.holes, + cmp.AllowUnexported(hole{}), + // Do not compare pkt in hole. Data will be altered. + cmp.Comparer(ignorePkt), + ); diff != "" { + t.Errorf("r.holes mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(test.wantPkt, resPkt, cmp.Comparer(cmpPktData)); diff != "" { + t.Errorf("Reassembled pkt mismatch (-want +got):\n%s", diff) + } + } else { + if diff := cmp.Diff( + test.want, r.holes, + cmp.AllowUnexported(hole{}), + cmp.Comparer(cmpPktData), + ); diff != "" { + t.Errorf("r.holes mismatch (-want +got):\n%s", diff) + } + } + }) + } +} diff --git a/pkg/tcpip/network/internal/ip/BUILD b/pkg/tcpip/network/internal/ip/BUILD index 0f55a9770..d21b4c7ef 100644 --- a/pkg/tcpip/network/internal/ip/BUILD +++ b/pkg/tcpip/network/internal/ip/BUILD @@ -4,8 +4,16 @@ package(licenses = ["notice"]) go_library( name = "ip", - srcs = ["duplicate_address_detection.go"], - visibility = ["//visibility:public"], + srcs = [ + "duplicate_address_detection.go", + "generic_multicast_protocol.go", + "stats.go", + ], + visibility = [ + "//pkg/tcpip/network/arp:__pkg__", + "//pkg/tcpip/network/ipv4:__pkg__", + "//pkg/tcpip/network/ipv6:__pkg__", + ], deps = [ "//pkg/sync", "//pkg/tcpip", @@ -16,7 +24,10 @@ go_library( go_test( name = "ip_x_test", size = "small", - srcs = ["duplicate_address_detection_test.go"], + srcs = [ + "duplicate_address_detection_test.go", + "generic_multicast_protocol_test.go", + ], deps = [ ":ip", "//pkg/sync", diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go new file mode 100644 index 000000000..b9f129728 --- /dev/null +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go @@ -0,0 +1,696 @@ +// Copyright 2020 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ip holds IPv4/IPv6 common utilities. +package ip + +import ( + "fmt" + "math/rand" + "time" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +// hostState is the state a host may be in for a multicast group. +type hostState int + +// The states below are generic across IGMPv2 (RFC 2236 section 6) and MLDv1 +// (RFC 2710 section 5). Even though the states are generic across both IGMPv2 +// and MLDv1, IGMPv2 terminology will be used. +// +// ______________receive query______________ +// | | +// | _____send or receive report_____ | +// | | | | +// V | V | +// +-------+ +-----------+ +------------+ +-------------------+ +--------+ | +// | Non-M | | Pending-M | | Delaying-M | | Queued Delaying-M | | Idle-M | - +// +-------+ +-----------+ +------------+ +-------------------+ +--------+ +// | ^ | ^ | ^ | ^ +// | | | | | | | | +// ---------- ------- ---------- ------------- +// initialize new send initial fail to send send or receive +// group membership report delayed report report +// +// Not shown in the diagram above, but any state may transition into the non +// member state when a group is left. +const ( + // nonMember is the "'Non-Member' state, when the host does not belong to the + // group on the interface. This is the initial state for all memberships on + // all network interfaces; it requires no storage in the host." 
+ // + // 'Non-Listener' is the MLDv1 term used to describe this state. + // + // This state is used to keep track of groups that have been joined locally, + // but without advertising the membership to the network. + nonMember hostState = iota + + // pendingMember is a newly joined member that is waiting to successfully send + // the initial set of reports. + // + // This is not an RFC defined state; it is an implementation specific state to + // track that the initial report needs to be sent. + // + // MAY NOT transition to the idle member state from this state. + pendingMember + + // delayingMember is the "'Delaying Member' state, when the host belongs to + // the group on the interface and has a report delay timer running for that + // membership." + // + // 'Delaying Listener' is the MLDv1 term used to describe this state. + delayingMember + + // queuedDelayingMember is a delayingMember that failed to send a report after + // its delayed report timer fired. Hosts in this state are waiting to attempt + // retransmission of the delayed report. + // + // This is not an RFC defined state; it is an implementation specific state to + // track that the delayed report needs to be sent. + // + // May transition to idle member if a report is received for a group. + queuedDelayingMember + + // idleMember is the "Idle Member" state, when the host belongs to the group + // on the interface and does not have a report delay timer running for that + // membership. + // + // 'Idle Listener' is the MLDv1 term used to describe this state. + idleMember +) + +func (s hostState) isDelayingMember() bool { + switch s { + case nonMember, pendingMember, idleMember: + return false + case delayingMember, queuedDelayingMember: + return true + default: + panic(fmt.Sprintf("unrecognized host state = %d", s)) + } +} + +// multicastGroupState holds the Generic Multicast Protocol state for a +// multicast group. 
type multicastGroupState struct {
	// joins is the number of times the group has been joined.
	joins uint64

	// state holds the host's state for the group.
	state hostState

	// lastToSendReport is true if we sent the last report for the group. It is
	// used to track whether there are other hosts on the subnet that are also
	// members of the group.
	//
	// Defined in RFC 2236 section 6 page 9 for IGMPv2 and RFC 2710 section 5 page
	// 8 for MLDv1.
	lastToSendReport bool

	// delayedReportJob is used to delay sending responses to membership report
	// messages in order to reduce duplicate reports from multiple hosts on the
	// interface.
	//
	// Must not be nil.
	delayedReportJob *tcpip.Job

	// delayedReportJobFiresAt is the time when the delayed report job will fire.
	//
	// A zero value indicates that the job is not scheduled.
	delayedReportJobFiresAt time.Time
}

// cancelDelayedReportJob cancels the delayed report job and resets
// delayedReportJobFiresAt to the zero time, marking the job as unscheduled.
func (m *multicastGroupState) cancelDelayedReportJob() {
	m.delayedReportJob.Cancel()
	m.delayedReportJobFiresAt = time.Time{}
}

// GenericMulticastProtocolOptions holds options for the generic multicast
// protocol.
type GenericMulticastProtocolOptions struct {
	// Rand is the source of random numbers.
	Rand *rand.Rand

	// Clock is the clock used to create timers.
	Clock tcpip.Clock

	// Protocol is the implementation of the variant of multicast group protocol
	// in use.
	Protocol MulticastGroupProtocol

	// MaxUnsolicitedReportDelay is the maximum amount of time to wait between
	// transmitting unsolicited reports.
	//
	// Unsolicited reports are transmitted when a group is newly joined.
	MaxUnsolicitedReportDelay time.Duration

	// AllNodesAddress is a multicast address that all nodes on a network should
	// be a member of.
	//
	// This address will not have the generic multicast protocol performed on it;
	// it will be left in the non member/listener state, and packets will never
	// be sent for it.
	AllNodesAddress tcpip.Address
}

// MulticastGroupProtocol is a multicast group protocol whose core state machine
// can be represented by GenericMulticastProtocolState.
type MulticastGroupProtocol interface {
	// Enabled indicates whether the generic multicast protocol will be
	// performed.
	//
	// When enabled, the protocol may transmit report and leave messages when
	// joining and leaving multicast groups respectively, and handle incoming
	// packets.
	//
	// When disabled, the protocol will still keep track of locally joined groups,
	// it just won't transmit and handle packets, or update groups' state.
	Enabled() bool

	// SendReport sends a multicast report for the specified group address.
	//
	// Returns false if the caller should queue the report to be sent later. Note,
	// returning false does not mean that the receiver hit an error.
	SendReport(groupAddress tcpip.Address) (sent bool, err tcpip.Error)

	// SendLeave sends a multicast leave for the specified group address.
	SendLeave(groupAddress tcpip.Address) tcpip.Error
}

// GenericMulticastProtocolState is the per interface generic multicast protocol
// state.
//
// There is actually no protocol named "Generic Multicast Protocol". Instead,
// the term is used to refer to a generic multicast protocol that applies to
// both IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core
// state machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC
// 2710.
//
// Callers must synchronize accesses to the generic multicast protocol state;
// GenericMulticastProtocolState obtains no locks in any of its methods. The
// only exception to this is GenericMulticastProtocolState's timer/job callbacks
// which will obtain the lock provided to the GenericMulticastProtocolState when
// it is initialized.
//
// GenericMulticastProtocolState.Init MUST be called before calling any of
// the methods on GenericMulticastProtocolState.
+// +// GenericMulticastProtocolState.MakeAllNonMemberLocked MUST be called when the +// multicast group protocol is disabled so that leave messages may be sent. +type GenericMulticastProtocolState struct { + // Do not allow overwriting this state. + _ sync.NoCopy + + opts GenericMulticastProtocolOptions + + // memberships holds group addresses and their associated state. + memberships map[tcpip.Address]multicastGroupState + + // protocolMU is the mutex used to protect the protocol. + protocolMU *sync.RWMutex +} + +// Init initializes the Generic Multicast Protocol state. +// +// Must only be called once for the lifetime of g; Init will panic if it is +// called twice. +// +// The GenericMulticastProtocolState will only grab the lock when timers/jobs +// fire. +// +// Note: the methods on opts.Protocol will always be called while protocolMU is +// held. +func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) { + if g.memberships != nil { + panic("attempted to initialize generic membership protocol state twice") + } + + *g = GenericMulticastProtocolState{ + opts: opts, + memberships: make(map[tcpip.Address]multicastGroupState), + protocolMU: protocolMU, + } +} + +// MakeAllNonMemberLocked transitions all groups to the non-member state. +// +// The groups will still be considered joined locally. +// +// MUST be called when the multicast group protocol is disabled. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() { + if !g.opts.Protocol.Enabled() { + return + } + + for groupAddress, info := range g.memberships { + g.transitionToNonMemberLocked(groupAddress, &info) + g.memberships[groupAddress] = info + } +} + +// InitializeGroupsLocked initializes each group, as if they were newly joined +// but without affecting the groups' join count. 
+// +// Must only be called after calling MakeAllNonMember as a group should not be +// initialized while it is not in the non-member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) InitializeGroupsLocked() { + if !g.opts.Protocol.Enabled() { + return + } + + for groupAddress, info := range g.memberships { + g.initializeNewMemberLocked(groupAddress, &info) + g.memberships[groupAddress] = info + } +} + +// SendQueuedReportsLocked attempts to send reports for groups that failed to +// send reports during their last attempt. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() { + for groupAddress, info := range g.memberships { + switch info.state { + case nonMember, delayingMember, idleMember: + case pendingMember: + // pendingMembers failed to send their initial unsolicited report so try + // to send the report and queue the extra unsolicited reports. + g.maybeSendInitialReportLocked(groupAddress, &info) + case queuedDelayingMember: + // queuedDelayingMembers failed to send their delayed reports so try to + // send the report and transition them to the idle state. + g.maybeSendDelayedReportLocked(groupAddress, &info) + default: + panic(fmt.Sprintf("unrecognized host state = %d", info.state)) + } + g.memberships[groupAddress] = info + } +} + +// JoinGroupLocked handles joining a new group. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address) { + if info, ok := g.memberships[groupAddress]; ok { + // The group has already been joined. + info.joins++ + g.memberships[groupAddress] = info + return + } + + info := multicastGroupState{ + // Since we just joined the group, its count is 1. + joins: 1, + // The state will be updated below, if required. 
+ state: nonMember, + lastToSendReport: false, + delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() { + if !g.opts.Protocol.Enabled() { + panic(fmt.Sprintf("delayed report job fired for group %s while the multicast group protocol is disabled", groupAddress)) + } + + info, ok := g.memberships[groupAddress] + if !ok { + panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress)) + } + + g.maybeSendDelayedReportLocked(groupAddress, &info) + g.memberships[groupAddress] = info + }), + } + + if g.opts.Protocol.Enabled() { + g.initializeNewMemberLocked(groupAddress, &info) + } + + g.memberships[groupAddress] = info +} + +// IsLocallyJoinedRLocked returns true if the group is locally joined. +// +// Precondition: g.protocolMU must be read locked. +func (g *GenericMulticastProtocolState) IsLocallyJoinedRLocked(groupAddress tcpip.Address) bool { + _, ok := g.memberships[groupAddress] + return ok +} + +// LeaveGroupLocked handles leaving the group. +// +// Returns false if the group is not currently joined. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Address) bool { + info, ok := g.memberships[groupAddress] + if !ok { + return false + } + + if info.joins == 0 { + panic(fmt.Sprintf("tried to leave group %s with a join count of 0", groupAddress)) + } + info.joins-- + if info.joins != 0 { + // If we still have outstanding joins, then do nothing further. + g.memberships[groupAddress] = info + return true + } + + g.transitionToNonMemberLocked(groupAddress, &info) + delete(g.memberships, groupAddress) + return true +} + +// HandleQueryLocked handles a query message with the specified maximum response +// time. +// +// If the group address is unspecified, then reports will be scheduled for all +// joined groups. +// +// Report(s) will be scheduled to be sent after a random duration between 0 and +// the maximum response time. 
+// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) { + if !g.opts.Protocol.Enabled() { + return + } + + // As per RFC 2236 section 2.4 (for IGMPv2), + // + // In a Membership Query message, the group address field is set to zero + // when sending a General Query, and set to the group address being + // queried when sending a Group-Specific Query. + // + // As per RFC 2710 section 3.6 (for MLDv1), + // + // In a Query message, the Multicast Address field is set to zero when + // sending a General Query, and set to a specific IPv6 multicast address + // when sending a Multicast-Address-Specific Query. + if groupAddress.Unspecified() { + // This is a general query as the group address is unspecified. + for groupAddress, info := range g.memberships { + g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) + g.memberships[groupAddress] = info + } + } else if info, ok := g.memberships[groupAddress]; ok { + g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) + g.memberships[groupAddress] = info + } +} + +// HandleReportLocked handles a report message. +// +// If the report is for a joined group, any active delayed report will be +// cancelled and the host state for the group transitions to idle. +// +// Precondition: g.protocolMU must be locked. 
+func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) { + if !g.opts.Protocol.Enabled() { + return + } + + // As per RFC 2236 section 3 pages 3-4 (for IGMPv2), + // + // If the host receives another host's Report (version 1 or 2) while it has + // a timer running, it stops its timer for the specified group and does not + // send a Report + // + // As per RFC 2710 section 4 page 6 (for MLDv1), + // + // If a node receives another node's Report from an interface for a + // multicast address while it has a timer running for that same address + // on that interface, it stops its timer and does not send a Report for + // that address, thus suppressing duplicate reports on the link. + if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() { + info.cancelDelayedReportJob() + info.lastToSendReport = false + info.state = idleMember + g.memberships[groupAddress] = info + } +} + +// initializeNewMemberLocked initializes a new group membership. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state != nonMember { + panic(fmt.Sprintf("host must be in non-member state to be initialized; group = %s, state = %d", groupAddress, info.state)) + } + + info.lastToSendReport = false + + if groupAddress == g.opts.AllNodesAddress { + // As per RFC 2236 section 6 page 10 (for IGMPv2), + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + // + // As per RFC 2710 section 5 page 10 (for MLDv1), + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. 
The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + info.state = idleMember + return + } + + info.state = pendingMember + g.maybeSendInitialReportLocked(groupAddress, info) +} + +// maybeSendInitialReportLocked attempts to start transmission of the initial +// set of reports after newly joining a group. +// +// Host must be in pending member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) maybeSendInitialReportLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state != pendingMember { + panic(fmt.Sprintf("host must be in pending member state to send initial reports; group = %s, state = %d", groupAddress, info.state)) + } + + // As per RFC 2236 section 3 page 5 (for IGMPv2), + // + // When a host joins a multicast group, it should immediately transmit an + // unsolicited Version 2 Membership Report for that group" ... "it is + // recommended that it be repeated". + // + // As per RFC 2710 section 4 page 6 (for MLDv1), + // + // When a node starts listening to a multicast address on an interface, + // it should immediately transmit an unsolicited Report for that address + // on that interface, in case it is the first listener on the link. To + // cover the possibility of the initial Report being lost or damaged, it + // is recommended that it be repeated once or twice after short delays + // [Unsolicited Report Interval]. + // + // TODO(gvisor.dev/issue/4901): Support a configurable number of initial + // unsolicited reports. + sent, err := g.opts.Protocol.SendReport(groupAddress) + if err == nil && sent { + info.lastToSendReport = true + g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay) + } +} + +// maybeSendDelayedReportLocked attempts to send the delayed report. 
//
// Host must be in delaying or queued delaying member state; any other state
// panics.
//
// If the report is sent, the group moves to the idle member state; otherwise
// it moves to (or stays in) the queued delaying member state.
//
// Precondition: g.protocolMU must be locked.
func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddress tcpip.Address, info *multicastGroupState) {
	if !info.state.isDelayingMember() {
		panic(fmt.Sprintf("host must be in delaying or queued delaying member state to send delayed reports; group = %s, state = %d", groupAddress, info.state))
	}

	sent, err := g.opts.Protocol.SendReport(groupAddress)
	if err == nil && sent {
		info.lastToSendReport = true
		info.state = idleMember
	} else {
		info.state = queuedDelayingMember
	}
}

// maybeSendLeave attempts to send a leave message.
//
// Nothing is sent if the protocol is disabled, if lastToSendReport is false,
// or if groupAddress is the all-nodes address.
func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) {
	if !g.opts.Protocol.Enabled() || !lastToSendReport {
		return
	}

	if groupAddress == g.opts.AllNodesAddress {
		// As per RFC 2236 section 6 page 10 (for IGMPv2),
		//
		//   The all-systems group (address 224.0.0.1) is handled as a special
		//   case. The host starts in Idle Member state for that group on every
		//   interface, never transitions to another state, and never sends a
		//   report for that group.
		//
		// As per RFC 2710 section 5 page 10 (for MLDv1),
		//
		//   The link-scope all-nodes address (FF02::1) is handled as a special
		//   case. The node starts in Idle Listener state for that address on
		//   every interface, never transitions to another state, and never sends
		//   a Report or Done for that address.
		return
	}

	// Okay to ignore the error here as if packet write failed, the multicast
	// routers will eventually drop our membership anyways. If the interface is
	// being disabled or removed, the generic multicast protocol's state should
	// be cleared eventually.
	//
	// As per RFC 2236 section 3 page 5 (for IGMPv2),
	//
	//   When a router receives a Report, it adds the group being reported to
	//   the list of multicast group memberships on the network on which it
	//   received the Report and sets the timer for the membership to the
	//   [Group Membership Interval]. Repeated Reports refresh the timer. If
	//   no Reports are received for a particular group before this timer has
	//   expired, the router assumes that the group has no local members and
	//   that it need not forward remotely-originated multicasts for that
	//   group onto the attached network.
	//
	// As per RFC 2710 section 4 page 5 (for MLDv1),
	//
	//   When a router receives a Report from a link, if the reported address
	//   is not already present in the router's list of multicast address
	//   having listeners on that link, the reported address is added to the
	//   list, its timer is set to [Multicast Listener Interval], and its
	//   appearance is made known to the router's multicast routing component.
	//   If a Report is received for a multicast address that is already
	//   present in the router's list, the timer for that address is reset to
	//   [Multicast Listener Interval]. If an address's timer expires, it is
	//   assumed that there are no longer any listeners for that address
	//   present on the link, so it is deleted from the list and its
	//   disappearance is made known to the multicast routing component.
	//
	// The requirement to send a leave message is also optional (it MAY be
	// skipped):
	//
	// As per RFC 2236 section 6 page 8 (for IGMPv2),
	//
	//   "send leave" for the group on the interface. If the interface
	//   state says the Querier is running IGMPv1, this action SHOULD be
	//   skipped. If the flag saying we were the last host to report is
	//   cleared, this action MAY be skipped. The Leave Message is sent to
	//   the ALL-ROUTERS group (224.0.0.2).
	//
	// As per RFC 2710 section 5 page 8 (for MLDv1),
	//
	//   "send done" for the address on the interface. If the flag saying
	//   we were the last node to report is cleared, this action MAY be
	//   skipped. The Done message is sent to the link-scope all-routers
	//   address (FF02::2).
	_ = g.opts.Protocol.SendLeave(groupAddress)
}

// transitionToNonMemberLocked transitions the given multicast group to the
// non-member/listener state.
//
// Does nothing if the group is already in the non-member state. Otherwise the
// delayed report job is cancelled and a leave message may be sent.
//
// Precondition: g.protocolMU must be locked.
func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
	if info.state == nonMember {
		return
	}

	info.cancelDelayedReportJob()
	// The leave decision uses the flag's value from before it is cleared below.
	g.maybeSendLeave(groupAddress, info.lastToSendReport)
	info.lastToSendReport = false
	info.state = nonMember
}

// setDelayTimerForAddressRLocked sets the timer to send a delayed report and
// moves the group to the delaying member state.
//
// Does nothing for groups in the non-member state or for the all-nodes
// address.
//
// NOTE(review): despite the RLocked suffix this function writes to *info and
// (re)schedules the group's job -- confirm the intended locking requirement.
//
// Precondition: g.protocolMU MUST be read locked.
func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) {
	if info.state == nonMember {
		return
	}

	if groupAddress == g.opts.AllNodesAddress {
		// As per RFC 2236 section 6 page 10 (for IGMPv2),
		//
		//   The all-systems group (address 224.0.0.1) is handled as a special
		//   case. The host starts in Idle Member state for that group on every
		//   interface, never transitions to another state, and never sends a
		//   report for that group.
		//
		// As per RFC 2710 section 5 page 10 (for MLDv1),
		//
		//   The link-scope all-nodes address (FF02::1) is handled as a special
		//   case. The node starts in Idle Listener state for that address on
		//   every interface, never transitions to another state, and never sends
		//   a Report or Done for that address.
		return
	}

	// As per RFC 2236 section 3 page 3 (for IGMPv2),
	//
	//   If a timer for the group is already running, it is reset to the random
	//   value only if the requested Max Response Time is less than the remaining
	//   value of the running timer.
	//
	// As per RFC 2710 section 4 page 5 (for MLDv1),
	//
	//   If a timer for any address is already running, it is reset to the new
	//   random value only if the requested Maximum Response Delay is less than
	//   the remaining value of the running timer.
	now := time.Unix(0 /* seconds */, g.opts.Clock.NowNanoseconds())
	if info.state == delayingMember {
		if info.delayedReportJobFiresAt.IsZero() {
			panic(fmt.Sprintf("delayed report unscheduled while in the delaying member state; group = %s", groupAddress))
		}

		if info.delayedReportJobFiresAt.Sub(now) <= maxResponseTime {
			// The timer is scheduled to fire before the maximum response time so we
			// leave our timer as is.
			return
		}
	}

	info.state = delayingMember
	info.cancelDelayedReportJob()
	// From here on maxResponseTime holds the randomly chosen delay, not the
	// maximum that was passed in.
	maxResponseTime = g.calculateDelayTimerDuration(maxResponseTime)
	info.delayedReportJob.Schedule(maxResponseTime)
	info.delayedReportJobFiresAt = now.Add(maxResponseTime)
}

// calculateDelayTimerDuration returns a random duration in the half-open range
// [0, maxRespTime), or 0 when maxRespTime is 0.
//
// NOTE(review): rand.Int63n panics when its argument is <= 0, so a negative
// maxRespTime would panic here -- confirm that callers never pass one.
func (g *GenericMulticastProtocolState) calculateDelayTimerDuration(maxRespTime time.Duration) time.Duration {
	// As per RFC 2236 section 3 page 3 (for IGMPv2),
	//
	//   When a host receives a Group-Specific Query, it sets a delay timer to a
	//   random value selected from the range (0, Max Response Time]...
	//
	// As per RFC 2710 section 4 page 6 (for MLDv1),
	//
	//   When a node receives a Multicast-Address-Specific Query, if it is
	//   listening to the queried Multicast Address on the interface from
	//   which the Query was received, it sets a delay timer for that address
	//   to a random value selected from the range [0, Maximum Response Delay],
	//   as above.
	if maxRespTime == 0 {
		return 0
	}
	return time.Duration(g.opts.Rand.Int63n(int64(maxRespTime)))
}
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
new file mode 100644
index 000000000..381460c82
--- /dev/null
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
@@ -0,0 +1,805 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ip_test

import (
	"math/rand"
	"testing"
	"time"

	"github.com/google/go-cmp/cmp"
	"gvisor.dev/gvisor/pkg/sync"
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/faketime"
	"gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
)

// maxUnsolicitedReportDelay is the MaxUnsolicitedReportDelay option used by the
// tests in this file.
const maxUnsolicitedReportDelay = time.Second

// Compile-time check that the mock implements ip.MulticastGroupProtocol.
var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil)

// mockMulticastGroupProtocolProtectedFields holds the mock's state together
// with the embedded RWMutex that guards it.
type mockMulticastGroupProtocolProtectedFields struct {
	sync.RWMutex

	// genericMulticastGroup is the state machine under test.
	genericMulticastGroup ip.GenericMulticastProtocolState
	// sendReportGroupAddrCount counts SendReport calls per group address.
	sendReportGroupAddrCount map[tcpip.Address]int
	// sendLeaveGroupAddrCount counts SendLeave calls per group address.
	sendLeaveGroupAddrCount map[tcpip.Address]int
	// makeQueuePackets, when true, makes SendReport return sent = false.
	makeQueuePackets bool
	// disabled, when true, makes Enabled return false.
	disabled bool
}

// mockMulticastGroupProtocol is an ip.MulticastGroupProtocol that records the
// reports and leaves it is asked to send.
type mockMulticastGroupProtocol struct {
	t *testing.T

	mu mockMulticastGroupProtocolProtectedFields
}

// init resets the mock's counters and initializes the generic multicast
// protocol state with opts, using the mock itself as opts.Protocol and the
// mock's mutex as the protocol lock.
func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.initLocked()
	opts.Protocol = m
	m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts)
}

// initLocked replaces the report and leave counters with empty maps.
//
// Precondition: m.mu must be locked.
func (m *mockMulticastGroupProtocol) initLocked() {
	m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int)
	m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
}

// setEnabled sets whether Enabled reports the protocol as enabled.
func (m *mockMulticastGroupProtocol) setEnabled(v bool) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.mu.disabled = !v
}

// setQueuePackets sets whether SendReport reports its packets as not sent.
func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.mu.makeQueuePackets = v
}

// joinGroup calls JoinGroupLocked with the write lock held.
func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.mu.genericMulticastGroup.JoinGroupLocked(addr)
}

// leaveGroup calls LeaveGroupLocked with the write lock held and returns its
// result.
func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.mu.genericMulticastGroup.LeaveGroupLocked(addr)
}

// handleReport calls HandleReportLocked with the write lock held.
func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.mu.genericMulticastGroup.HandleReportLocked(addr)
}

// handleQuery calls HandleQueryLocked with the write lock held.
func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime)
}

// isLocallyJoined calls IsLocallyJoinedRLocked with the read lock held and
// returns its result.
func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr)
}

// makeAllNonMember calls MakeAllNonMemberLocked with the write lock held.
func (m *mockMulticastGroupProtocol) makeAllNonMember() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.mu.genericMulticastGroup.MakeAllNonMemberLocked()
}

// initializeGroups calls InitializeGroupsLocked with the write lock held.
func (m *mockMulticastGroupProtocol) initializeGroups() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.mu.genericMulticastGroup.InitializeGroupsLocked()
}

// sendQueuedReports calls SendQueuedReportsLocked with the write lock held.
func (m *mockMulticastGroupProtocol) sendQueuedReports() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.mu.genericMulticastGroup.SendQueuedReportsLocked()
}

// Enabled implements ip.MulticastGroupProtocol.
//
// Fails the test if the write lock can be acquired, since that means the
// caller holds neither the read nor the write lock.
//
// Precondition: m.mu must be read locked.
func (m *mockMulticastGroupProtocol) Enabled() bool {
	if m.mu.TryLock() {
		m.mu.Unlock()
		m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled")
	}

	return !m.mu.disabled
}

// SendReport implements ip.MulticastGroupProtocol.
//
// Fails the test if either the write or the read lock can be acquired, since
// that means the caller does not hold the write lock. Otherwise it counts the
// report and returns sent = false only when packets are set to be queued.
//
// Precondition: m.mu must be locked.
func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
	if m.mu.TryLock() {
		m.mu.Unlock()
		m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
	}
	if m.mu.TryRLock() {
		m.mu.RUnlock()
		m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
	}

	m.mu.sendReportGroupAddrCount[groupAddress]++
	return !m.mu.makeQueuePackets, nil
}

// SendLeave implements ip.MulticastGroupProtocol.
//
// Precondition: m.mu must be locked.
+func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error { + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) + } + if m.mu.TryRLock() { + m.mu.RUnlock() + m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) + } + + m.mu.sendLeaveGroupAddrCount[groupAddress]++ + return nil +} + +func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { + m.mu.Lock() + defer m.mu.Unlock() + + sendReportGroupAddrCount := make(map[tcpip.Address]int) + for _, a := range sendReportGroupAddresses { + sendReportGroupAddrCount[a] = 1 + } + + sendLeaveGroupAddrCount := make(map[tcpip.Address]int) + for _, a := range sendLeaveGroupAddresses { + sendLeaveGroupAddrCount[a] = 1 + } + + diff := cmp.Diff( + &mockMulticastGroupProtocol{ + mu: mockMulticastGroupProtocolProtectedFields{ + sendReportGroupAddrCount: sendReportGroupAddrCount, + sendLeaveGroupAddrCount: sendLeaveGroupAddrCount, + }, + }, + m, + cmp.AllowUnexported(mockMulticastGroupProtocol{}), + cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}), + // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t + cmp.FilterPath( + func(p cmp.Path) bool { + switch p.Last().String() { + case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup": + return true + } + return false + }, + cmp.Ignore(), + ), + ) + m.initLocked() + return diff +} + +func TestJoinGroup(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + shouldSendReports bool + }{ + { + name: "Normal group", + addr: addr1, + shouldSendReports: true, + }, + { + name: "All-nodes group", + addr: addr2, + shouldSendReports: false, + }, + } + + for _, test := range 
tests { + t.Run(test.name, func(t *testing.T) { + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(0)), + Clock: clock, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr2, + }) + + // Joining a group should send a report immediately and another after + // a random interval between 0 and the maximum unsolicited report delay. + mgp.joinGroup(test.addr) + if test.shouldSendReports { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. 
+ clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestLeaveGroup(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + shouldSendMessages bool + }{ + { + name: "Normal group", + addr: addr1, + shouldSendMessages: true, + }, + { + name: "All-nodes group", + addr: addr2, + shouldSendMessages: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(1)), + Clock: clock, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr2, + }) + + mgp.joinGroup(test.addr) + if test.shouldSendMessages { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Leaving a group should send a leave report immediately and cancel any + // delayed reports. + { + + if !mgp.leaveGroup(test.addr) { + t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr) + } + } + if test.shouldSendMessages { + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + // + // Generic multicast protocol timers are expected to take the job mutex. 
+ clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestHandleReport(t *testing.T) { + tests := []struct { + name string + reportAddr tcpip.Address + expectReportsFor []tcpip.Address + }{ + { + name: "Unpecified empty", + reportAddr: "", + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Unpecified any", + reportAddr: "\x00", + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Specified", + reportAddr: addr1, + expectReportsFor: []tcpip.Address{addr2}, + }, + { + name: "Specified all-nodes", + reportAddr: addr3, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Specified other", + reportAddr: addr4, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(2)), + Clock: clock, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr3, + }) + + mgp.joinGroup(addr1) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.joinGroup(addr2) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.joinGroup(addr3) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a report for a group we have a timer scheduled for should + // cancel 
our delayed report timer for the group. + mgp.handleReport(test.reportAddr) + if len(test.expectReportsFor) != 0 { + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestHandleQuery(t *testing.T) { + tests := []struct { + name string + queryAddr tcpip.Address + maxDelay time.Duration + expectQueriedReportsFor []tcpip.Address + expectDelayedReportsFor []tcpip.Address + }{ + { + name: "Unpecified empty", + queryAddr: "", + maxDelay: 0, + expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, + expectDelayedReportsFor: nil, + }, + { + name: "Unpecified any", + queryAddr: "\x00", + maxDelay: 1, + expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, + expectDelayedReportsFor: nil, + }, + { + name: "Specified", + queryAddr: addr1, + maxDelay: 2, + expectQueriedReportsFor: []tcpip.Address{addr1}, + expectDelayedReportsFor: []tcpip.Address{addr2}, + }, + { + name: "Specified all-nodes", + queryAddr: addr3, + maxDelay: 3, + expectQueriedReportsFor: nil, + expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Specified other", + queryAddr: addr4, + maxDelay: 4, + expectQueriedReportsFor: nil, + expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + 
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr3, + }) + + mgp.joinGroup(addr1) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.joinGroup(addr2) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.joinGroup(addr3) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a query should make us reschedule our delayed report timer + // to some time within the new max response delay. + mgp.handleQuery(test.queryAddr, test.maxDelay) + clock.Advance(test.maxDelay) + if diff := mgp.check(test.expectQueriedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // The groups that were not affected by the query should still send a + // report after the max unsolicited report delay. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check(test.expectDelayedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should have no more messages to send. 
+ clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestJoinCount(t *testing.T) { + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(4)), + Clock: clock, + MaxUnsolicitedReportDelay: time.Second, + }) + + // Set the join count to 2 for a group. + mgp.joinGroup(addr1) + if !mgp.isLocallyJoined(addr1) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) + } + // Only the first join should trigger a report to be sent. + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.joinGroup(addr1) + if !mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() + } + + // Group should still be considered joined after leaving once. + if !mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) + } + if !mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) + } + // A leave report should only be sent once the join count reaches 0. + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() + } + + // Leaving once more should actually remove us from the group. 
+ if !mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) + } + if mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() + } + + // Group should no longer be joined so we should not have anything to + // leave. + if mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1) + } + if mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should have no more messages to send. + // + // Generic multicast protocol timers are expected to take the job mutex. 
+ clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} + +func TestMakeAllNonMemberAndInitialize(t *testing.T) { + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr3, + }) + + mgp.joinGroup(addr1) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.joinGroup(addr2) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.joinGroup(addr3) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should send the leave reports for each but still consider them locally + // joined. + mgp.makeAllNonMember() + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // Generic multicast protocol timers are expected to take the job mutex. 
+ clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + for _, group := range []tcpip.Address{addr1, addr2, addr3} { + if !mgp.isLocallyJoined(group) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group) + } + } + + // Should send the initial set of unsolicited reports. + mgp.initializeGroups() + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} + +// TestGroupStateNonMember tests that groups do not send packets when in the +// non-member state, but are still considered locally joined. +func TestGroupStateNonMember(t *testing.T) { + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + }) + mgp.setEnabled(false) + + // Joining groups should not send any reports. 
+ mgp.joinGroup(addr1) + if !mgp.isLocallyJoined(addr1) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.joinGroup(addr2) + if !mgp.isLocallyJoined(addr2) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2) + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a query should not send any reports. + mgp.handleQuery(addr1, time.Nanosecond) + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(time.Nanosecond) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Leaving groups should not send any leave messages. 
+ if !mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) + } + if mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} + +func TestQueuedPackets(t *testing.T) { + clock := faketime.NewManualClock() + mgp := mockMulticastGroupProtocol{t: t} + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(4)), + Clock: clock, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + }) + + // Joining should trigger a SendReport, but mgp should report that we did not + // send the packet. + mgp.setQueuePackets(true) + mgp.joinGroup(addr1) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // The delayed report timer should have been cancelled since we did not send + // the initial report earlier. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Mock being able to successfully send the report. + mgp.setQueuePackets(false) + mgp.sendQueuedReports() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // The delayed report (sent after the initial report) should now be sent. 
+ clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send (we should be idle). + mgp.sendQueuedReports() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receive a query but mock being unable to send reports again. + mgp.setQueuePackets(true) + mgp.handleQuery(addr1, time.Nanosecond) + clock.Advance(time.Nanosecond) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Mock being able to send reports again - we should have a packet queued to + // send. + mgp.setQueuePackets(false) + mgp.sendQueuedReports() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send. + mgp.sendQueuedReports() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receive a query again, but mock being unable to send reports. 
+ mgp.setQueuePackets(true) + mgp.handleQuery(addr1, time.Nanosecond) + clock.Advance(time.Nanosecond) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a report should should transition us into the idle member state, + // even if we had a packet queued. We should no longer have any packets to + // send. + mgp.handleReport(addr1) + mgp.sendQueuedReports() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // When we fail to send the initial set of reports, incoming reports should + // not affect a newly joined group's reports from being sent. + mgp.setQueuePackets(true) + mgp.joinGroup(addr2) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.handleReport(addr2) + // Attempting to send queued reports while still unable to send reports should + // not change the host state. + mgp.sendQueuedReports() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // Mock being able to successfully send the report. + mgp.setQueuePackets(false) + mgp.sendQueuedReports() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // The delayed report (sent after the initial report) should now be sent. 
+ clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send. + mgp.sendQueuedReports() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go new file mode 100644 index 000000000..898f8b356 --- /dev/null +++ b/pkg/tcpip/network/internal/ip/stats.go @@ -0,0 +1,100 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip + +import "gvisor.dev/gvisor/pkg/tcpip" + +// LINT.IfChange(MultiCounterIPStats) + +// MultiCounterIPStats holds IP statistics, each counter may have several +// versions. +type MultiCounterIPStats struct { + // PacketsReceived is the total number of IP packets received from the link + // layer. + PacketsReceived tcpip.MultiCounterStat + + // DisabledPacketsReceived is the total number of IP packets received from the + // link layer when the IP layer is disabled. 
+ DisabledPacketsReceived tcpip.MultiCounterStat + + // InvalidDestinationAddressesReceived is the total number of IP packets + // received with an unknown or invalid destination address. + InvalidDestinationAddressesReceived tcpip.MultiCounterStat + + // InvalidSourceAddressesReceived is the total number of IP packets received + // with a source address that should never have been received on the wire. + InvalidSourceAddressesReceived tcpip.MultiCounterStat + + // PacketsDelivered is the total number of incoming IP packets that are + // successfully delivered to the transport layer. + PacketsDelivered tcpip.MultiCounterStat + + // PacketsSent is the total number of IP packets sent via WritePacket. + PacketsSent tcpip.MultiCounterStat + + // OutgoingPacketErrors is the total number of IP packets which failed to + // write to a link-layer endpoint. + OutgoingPacketErrors tcpip.MultiCounterStat + + // MalformedPacketsReceived is the total number of IP Packets that were + // dropped due to the IP packet header failing validation checks. + MalformedPacketsReceived tcpip.MultiCounterStat + + // MalformedFragmentsReceived is the total number of IP Fragments that were + // dropped due to the fragment failing validation checks. + MalformedFragmentsReceived tcpip.MultiCounterStat + + // IPTablesPreroutingDropped is the total number of IP packets dropped in the + // Prerouting chain. + IPTablesPreroutingDropped tcpip.MultiCounterStat + + // IPTablesInputDropped is the total number of IP packets dropped in the Input + // chain. + IPTablesInputDropped tcpip.MultiCounterStat + + // IPTablesOutputDropped is the total number of IP packets dropped in the + // Output chain. + IPTablesOutputDropped tcpip.MultiCounterStat + + // OptionTSReceived is the number of Timestamp options seen. + OptionTSReceived tcpip.MultiCounterStat + + // OptionRRReceived is the number of Record Route options seen. 
+ OptionRRReceived tcpip.MultiCounterStat + + // OptionUnknownReceived is the number of unknown IP options seen. + OptionUnknownReceived tcpip.MultiCounterStat +} + +// Init sets internal counters to track a and b counters. +func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) { + m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived) + m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived) + m.InvalidDestinationAddressesReceived.Init(a.InvalidDestinationAddressesReceived, b.InvalidDestinationAddressesReceived) + m.InvalidSourceAddressesReceived.Init(a.InvalidSourceAddressesReceived, b.InvalidSourceAddressesReceived) + m.PacketsDelivered.Init(a.PacketsDelivered, b.PacketsDelivered) + m.PacketsSent.Init(a.PacketsSent, b.PacketsSent) + m.OutgoingPacketErrors.Init(a.OutgoingPacketErrors, b.OutgoingPacketErrors) + m.MalformedPacketsReceived.Init(a.MalformedPacketsReceived, b.MalformedPacketsReceived) + m.MalformedFragmentsReceived.Init(a.MalformedFragmentsReceived, b.MalformedFragmentsReceived) + m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped) + m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped) + m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped) + m.OptionTSReceived.Init(a.OptionTSReceived, b.OptionTSReceived) + m.OptionRRReceived.Init(a.OptionRRReceived, b.OptionRRReceived) + m.OptionUnknownReceived.Init(a.OptionUnknownReceived, b.OptionUnknownReceived) +} + +// LINT.ThenChange(:MultiCounterIPStats, ../../tcpip.go:IPStats) diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD new file mode 100644 index 000000000..1c4f583c7 --- /dev/null +++ b/pkg/tcpip/network/internal/testutil/BUILD @@ -0,0 +1,23 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "testutil", + srcs = [ + "testutil.go", + "testutil_unsafe.go", + ], + visibility = [ + 
"//pkg/tcpip/network/arp:__pkg__", + "//pkg/tcpip/network/internal/fragmentation:__pkg__", + "//pkg/tcpip/network/ipv4:__pkg__", + "//pkg/tcpip/network/ipv6:__pkg__", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go new file mode 100644 index 000000000..f5fa77b65 --- /dev/null +++ b/pkg/tcpip/network/internal/testutil/testutil.go @@ -0,0 +1,197 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testutil defines types and functions used to test Network Layer +// functionality such as IP fragmentation. +package testutil + +import ( + "fmt" + "math/rand" + "reflect" + "strings" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// MockLinkEndpoint is an endpoint used for testing, it stores packets written +// to it and can mock errors. +type MockLinkEndpoint struct { + // WrittenPackets is where packets written to the endpoint are stored. + WrittenPackets []*stack.PacketBuffer + + mtu uint32 + err tcpip.Error + allowPackets int +} + +// NewMockLinkEndpoint creates a new MockLinkEndpoint. +// +// err is the error that will be returned once allowPackets packets are written +// to the endpoint. 
+func NewMockLinkEndpoint(mtu uint32, err tcpip.Error, allowPackets int) *MockLinkEndpoint { + return &MockLinkEndpoint{ + mtu: mtu, + err: err, + allowPackets: allowPackets, + } +} + +// MTU implements LinkEndpoint.MTU. +func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu } + +// Capabilities implements LinkEndpoint.Capabilities. +func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 } + +// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. +func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 } + +// LinkAddress implements LinkEndpoint.LinkAddress. +func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } + +// WritePacket implements LinkEndpoint.WritePacket. +func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + if ep.allowPackets == 0 { + return ep.err + } + ep.allowPackets-- + ep.WrittenPackets = append(ep.WrittenPackets, pkt) + return nil +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + var n int + + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err := ep.WritePacket(r, gso, protocol, pkt); err != nil { + return n, err + } + n++ + } + + return n, nil +} + +// Attach implements LinkEndpoint.Attach. +func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {} + +// IsAttached implements LinkEndpoint.IsAttached. +func (*MockLinkEndpoint) IsAttached() bool { return false } + +// Wait implements LinkEndpoint.Wait. +func (*MockLinkEndpoint) Wait() {} + +// ARPHardwareType implements LinkEndpoint.ARPHardwareType. +func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } + +// AddHeader implements LinkEndpoint.AddHeader. 
+func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { +} + +// MakeRandPkt generates a randomized packet. transportHeaderLength indicates +// how many random bytes will be copied in the Transport Header. +// extraHeaderReserveLength indicates how much extra space will be reserved for +// the other headers. The payload is made from Views of the sizes listed in +// viewSizes. +func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer { + var views buffer.VectorisedView + + for _, s := range viewSizes { + newView := buffer.NewView(s) + if _, err := rand.Read(newView); err != nil { + panic(fmt.Sprintf("rand.Read: %s", err)) + } + views.AppendView(newView) + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength, + Data: views, + }) + pkt.NetworkProtocolNumber = proto + if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil { + panic(fmt.Sprintf("rand.Read: %s", err)) + } + return pkt +} + +func checkFieldCounts(ref, multi reflect.Value) error { + refTypeName := ref.Type().Name() + multiTypeName := multi.Type().Name() + refNumField := ref.NumField() + multiNumField := multi.NumField() + + if refNumField != multiNumField { + return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName) + } + + return nil +} + +func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error { + s, ok := ref.Addr().Interface().(**tcpip.StatCounter) + if !ok { + return fmt.Errorf("expected ref type to be *StatCounter, but its type is %s", ref.Type()) + } + + // The field names are expected to match (case insensitive). 
+ if !strings.EqualFold(refName, multiName) { + return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName) + } + + base := (*s).Value() + m.Increment() + if (*s).Value() != base+1 { + return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName) + } + + return nil +} + +// ValidateMultiCounterStats verifies that every counter stored in multi is +// correctly tracking its counterpart in the given counters. +func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error { + for _, c := range counters { + if err := checkFieldCounts(c, multi); err != nil { + return err + } + } + + for i := 0; i < multi.NumField(); i++ { + multiName := multi.Type().Field(i).Name + multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i)) + + if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok { + for _, c := range counters { + if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil { + return err + } + } + } else { + var countersNextField []reflect.Value + for _, c := range counters { + countersNextField = append(countersNextField, c.Field(i)) + } + if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil { + return err + } + } + } + + return nil +} diff --git a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go b/pkg/tcpip/network/internal/testutil/testutil_unsafe.go new file mode 100644 index 000000000..5ff764800 --- /dev/null +++ b/pkg/tcpip/network/internal/testutil/testutil_unsafe.go @@ -0,0 +1,26 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "reflect" + "unsafe" +) + +// unsafeExposeUnexportedFields takes a Value and returns a version of it in +// which even unexported fields can be read and written. +func unsafeExposeUnexportedFields(a reflect.Value) reflect.Value { + return reflect.NewAt(a.Type(), unsafe.Pointer(a.UnsafeAddr())).Elem() +} diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD deleted file mode 100644 index 411bca25d..000000000 --- a/pkg/tcpip/network/ip/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ip", - srcs = [ - "generic_multicast_protocol.go", - "stats.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - ], -) - -go_test( - name = "ip_test", - size = "small", - srcs = ["generic_multicast_protocol_test.go"], - deps = [ - ":ip", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/faketime", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go deleted file mode 100644 index b9f129728..000000000 --- a/pkg/tcpip/network/ip/generic_multicast_protocol.go +++ /dev/null @@ -1,696 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package ip holds IPv4/IPv6 common utilities. -package ip - -import ( - "fmt" - "math/rand" - "time" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" -) - -// hostState is the state a host may be in for a multicast group. -type hostState int - -// The states below are generic across IGMPv2 (RFC 2236 section 6) and MLDv1 -// (RFC 2710 section 5). Even though the states are generic across both IGMPv2 -// and MLDv1, IGMPv2 terminology will be used. -// -// ______________receive query______________ -// | | -// | _____send or receive report_____ | -// | | | | -// V | V | -// +-------+ +-----------+ +------------+ +-------------------+ +--------+ | -// | Non-M | | Pending-M | | Delaying-M | | Queued Delaying-M | | Idle-M | - -// +-------+ +-----------+ +------------+ +-------------------+ +--------+ -// | ^ | ^ | ^ | ^ -// | | | | | | | | -// ---------- ------- ---------- ------------- -// initialize new send initial fail to send send or receive -// group membership report delayed report report -// -// Not shown in the diagram above, but any state may transition into the non -// member state when a group is left. -const ( - // nonMember is the "'Non-Member' state, when the host does not belong to the - // group on the interface. This is the initial state for all memberships on - // all network interfaces; it requires no storage in the host." - // - // 'Non-Listener' is the MLDv1 term used to describe this state.
- // - // This state is used to keep track of groups that have been joined locally, - // but without advertising the membership to the network. - nonMember hostState = iota - - // pendingMember is a newly joined member that is waiting to successfully send - // the initial set of reports. - // - // This is not an RFC defined state; it is an implementation specific state to - // track that the initial report needs to be sent. - // - // MAY NOT transition to the idle member state from this state. - pendingMember - - // delayingMember is the "'Delaying Member' state, when the host belongs to - // the group on the interface and has a report delay timer running for that - // membership." - // - // 'Delaying Listener' is the MLDv1 term used to describe this state. - delayingMember - - // queuedDelayingMember is a delayingMember that failed to send a report after - // its delayed report timer fired. Hosts in this state are waiting to attempt - // retransmission of the delayed report. - // - // This is not an RFC defined state; it is an implementation specific state to - // track that the delayed report needs to be sent. - // - // May transition to idle member if a report is received for a group. - queuedDelayingMember - - // idleMember is the "Idle Member" state, when the host belongs to the group - // on the interface and does not have a report delay timer running for that - // membership. - // - // 'Idle Listener' is the MLDv1 term used to describe this state. - idleMember -) - -func (s hostState) isDelayingMember() bool { - switch s { - case nonMember, pendingMember, idleMember: - return false - case delayingMember, queuedDelayingMember: - return true - default: - panic(fmt.Sprintf("unrecognized host state = %d", s)) - } -} - -// multicastGroupState holds the Generic Multicast Protocol state for a -// multicast group. -type multicastGroupState struct { - // joins is the number of times the group has been joined. 
- joins uint64 - - // state holds the host's state for the group. - state hostState - - // lastToSendReport is true if we sent the last report for the group. It is - // used to track whether there are other hosts on the subnet that are also - // members of the group. - // - // Defined in RFC 2236 section 6 page 9 for IGMPv2 and RFC 2710 section 5 page - // 8 for MLDv1. - lastToSendReport bool - - // delayedReportJob is used to delay sending responses to membership report - // messages in order to reduce duplicate reports from multiple hosts on the - // interface. - // - // Must not be nil. - delayedReportJob *tcpip.Job - - // delayedReportJobFiresAt is the time when the delayed report job will fire. - // - // A zero value indicates that the job is not scheduled. - delayedReportJobFiresAt time.Time -} - -func (m *multicastGroupState) cancelDelayedReportJob() { - m.delayedReportJob.Cancel() - m.delayedReportJobFiresAt = time.Time{} -} - -// GenericMulticastProtocolOptions holds options for the generic multicast -// protocol. -type GenericMulticastProtocolOptions struct { - // Rand is the source of random numbers. - Rand *rand.Rand - - // Clock is the clock used to create timers. - Clock tcpip.Clock - - // Protocol is the implementation of the variant of multicast group protocol - // in use. - Protocol MulticastGroupProtocol - - // MaxUnsolicitedReportDelay is the maximum amount of time to wait between - // transmitting unsolicited reports. - // - // Unsolicited reports are transmitted when a group is newly joined. - MaxUnsolicitedReportDelay time.Duration - - // AllNodesAddress is a multicast address that all nodes on a network should - // be a member of. - // - // This address will not have the generic multicast protocol performed on it; - // it will be left in the non member/listener state, and packets will never - // be sent for it.
- AllNodesAddress tcpip.Address -} - -// MulticastGroupProtocol is a multicast group protocol whose core state machine -// can be represented by GenericMulticastProtocolState. -type MulticastGroupProtocol interface { - // Enabled indicates whether the generic multicast protocol will be - // performed. - // - // When enabled, the protocol may transmit report and leave messages when - // joining and leaving multicast groups respectively, and handle incoming - // packets. - // - // When disabled, the protocol will still keep track of locally joined groups, - // it just won't transmit and handle packets, or update groups' state. - Enabled() bool - - // SendReport sends a multicast report for the specified group address. - // - // Returns false if the caller should queue the report to be sent later. Note, - // returning false does not mean that the receiver hit an error. - SendReport(groupAddress tcpip.Address) (sent bool, err tcpip.Error) - - // SendLeave sends a multicast leave for the specified group address. - SendLeave(groupAddress tcpip.Address) tcpip.Error -} - -// GenericMulticastProtocolState is the per interface generic multicast protocol -// state. -// -// There is actually no protocol named "Generic Multicast Protocol". Instead, -// the term is used to refer to a generic multicast protocol that applies to both -// IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state -// machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710. -// -// Callers must synchronize accesses to the generic multicast protocol state; -// GenericMulticastProtocolState obtains no locks in any of its methods. The -// only exception to this is GenericMulticastProtocolState's timer/job callbacks -// which will obtain the lock provided to the GenericMulticastProtocolState when -// it is initialized. -// -// GenericMulticastProtocolState.Init MUST be called before calling any of -// the methods on GenericMulticastProtocolState.
-// -// GenericMulticastProtocolState.MakeAllNonMemberLocked MUST be called when the -// multicast group protocol is disabled so that leave messages may be sent. -type GenericMulticastProtocolState struct { - // Do not allow overwriting this state. - _ sync.NoCopy - - opts GenericMulticastProtocolOptions - - // memberships holds group addresses and their associated state. - memberships map[tcpip.Address]multicastGroupState - - // protocolMU is the mutex used to protect the protocol. - protocolMU *sync.RWMutex -} - -// Init initializes the Generic Multicast Protocol state. -// -// Must only be called once for the lifetime of g; Init will panic if it is -// called twice. -// -// The GenericMulticastProtocolState will only grab the lock when timers/jobs -// fire. -// -// Note: the methods on opts.Protocol will always be called while protocolMU is -// held. -func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) { - if g.memberships != nil { - panic("attempted to initialize generic membership protocol state twice") - } - - *g = GenericMulticastProtocolState{ - opts: opts, - memberships: make(map[tcpip.Address]multicastGroupState), - protocolMU: protocolMU, - } -} - -// MakeAllNonMemberLocked transitions all groups to the non-member state. -// -// The groups will still be considered joined locally. -// -// MUST be called when the multicast group protocol is disabled. -// -// Precondition: g.protocolMU must be locked. -func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() { - if !g.opts.Protocol.Enabled() { - return - } - - for groupAddress, info := range g.memberships { - g.transitionToNonMemberLocked(groupAddress, &info) - g.memberships[groupAddress] = info - } -} - -// InitializeGroupsLocked initializes each group, as if they were newly joined -// but without affecting the groups' join count. 
-// -// Must only be called after calling MakeAllNonMember as a group should not be -// initialized while it is not in the non-member state. -// -// Precondition: g.protocolMU must be locked. -func (g *GenericMulticastProtocolState) InitializeGroupsLocked() { - if !g.opts.Protocol.Enabled() { - return - } - - for groupAddress, info := range g.memberships { - g.initializeNewMemberLocked(groupAddress, &info) - g.memberships[groupAddress] = info - } -} - -// SendQueuedReportsLocked attempts to send reports for groups that failed to -// send reports during their last attempt. -// -// Precondition: g.protocolMU must be locked. -func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() { - for groupAddress, info := range g.memberships { - switch info.state { - case nonMember, delayingMember, idleMember: - case pendingMember: - // pendingMembers failed to send their initial unsolicited report so try - // to send the report and queue the extra unsolicited reports. - g.maybeSendInitialReportLocked(groupAddress, &info) - case queuedDelayingMember: - // queuedDelayingMembers failed to send their delayed reports so try to - // send the report and transition them to the idle state. - g.maybeSendDelayedReportLocked(groupAddress, &info) - default: - panic(fmt.Sprintf("unrecognized host state = %d", info.state)) - } - g.memberships[groupAddress] = info - } -} - -// JoinGroupLocked handles joining a new group. -// -// Precondition: g.protocolMU must be locked. -func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address) { - if info, ok := g.memberships[groupAddress]; ok { - // The group has already been joined. - info.joins++ - g.memberships[groupAddress] = info - return - } - - info := multicastGroupState{ - // Since we just joined the group, its count is 1. - joins: 1, - // The state will be updated below, if required. 
- state: nonMember, - lastToSendReport: false, - delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() { - if !g.opts.Protocol.Enabled() { - panic(fmt.Sprintf("delayed report job fired for group %s while the multicast group protocol is disabled", groupAddress)) - } - - info, ok := g.memberships[groupAddress] - if !ok { - panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress)) - } - - g.maybeSendDelayedReportLocked(groupAddress, &info) - g.memberships[groupAddress] = info - }), - } - - if g.opts.Protocol.Enabled() { - g.initializeNewMemberLocked(groupAddress, &info) - } - - g.memberships[groupAddress] = info -} - -// IsLocallyJoinedRLocked returns true if the group is locally joined. -// -// Precondition: g.protocolMU must be read locked. -func (g *GenericMulticastProtocolState) IsLocallyJoinedRLocked(groupAddress tcpip.Address) bool { - _, ok := g.memberships[groupAddress] - return ok -} - -// LeaveGroupLocked handles leaving the group. -// -// Returns false if the group is not currently joined. -// -// Precondition: g.protocolMU must be locked. -func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Address) bool { - info, ok := g.memberships[groupAddress] - if !ok { - return false - } - - if info.joins == 0 { - panic(fmt.Sprintf("tried to leave group %s with a join count of 0", groupAddress)) - } - info.joins-- - if info.joins != 0 { - // If we still have outstanding joins, then do nothing further. - g.memberships[groupAddress] = info - return true - } - - g.transitionToNonMemberLocked(groupAddress, &info) - delete(g.memberships, groupAddress) - return true -} - -// HandleQueryLocked handles a query message with the specified maximum response -// time. -// -// If the group address is unspecified, then reports will be scheduled for all -// joined groups. -// -// Report(s) will be scheduled to be sent after a random duration between 0 and -// the maximum response time. 
-// -// Precondition: g.protocolMU must be locked. -func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) { - if !g.opts.Protocol.Enabled() { - return - } - - // As per RFC 2236 section 2.4 (for IGMPv2), - // - // In a Membership Query message, the group address field is set to zero - // when sending a General Query, and set to the group address being - // queried when sending a Group-Specific Query. - // - // As per RFC 2710 section 3.6 (for MLDv1), - // - // In a Query message, the Multicast Address field is set to zero when - // sending a General Query, and set to a specific IPv6 multicast address - // when sending a Multicast-Address-Specific Query. - if groupAddress.Unspecified() { - // This is a general query as the group address is unspecified. - for groupAddress, info := range g.memberships { - g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) - g.memberships[groupAddress] = info - } - } else if info, ok := g.memberships[groupAddress]; ok { - g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) - g.memberships[groupAddress] = info - } -} - -// HandleReportLocked handles a report message. -// -// If the report is for a joined group, any active delayed report will be -// cancelled and the host state for the group transitions to idle. -// -// Precondition: g.protocolMU must be locked. 
-func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) { - if !g.opts.Protocol.Enabled() { - return - } - - // As per RFC 2236 section 3 pages 3-4 (for IGMPv2), - // - // If the host receives another host's Report (version 1 or 2) while it has - // a timer running, it stops its timer for the specified group and does not - // send a Report - // - // As per RFC 2710 section 4 page 6 (for MLDv1), - // - // If a node receives another node's Report from an interface for a - // multicast address while it has a timer running for that same address - // on that interface, it stops its timer and does not send a Report for - // that address, thus suppressing duplicate reports on the link. - if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() { - info.cancelDelayedReportJob() - info.lastToSendReport = false - info.state = idleMember - g.memberships[groupAddress] = info - } -} - -// initializeNewMemberLocked initializes a new group membership. -// -// Precondition: g.protocolMU must be locked. -func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { - if info.state != nonMember { - panic(fmt.Sprintf("host must be in non-member state to be initialized; group = %s, state = %d", groupAddress, info.state)) - } - - info.lastToSendReport = false - - if groupAddress == g.opts.AllNodesAddress { - // As per RFC 2236 section 6 page 10 (for IGMPv2), - // - // The all-systems group (address 224.0.0.1) is handled as a special - // case. The host starts in Idle Member state for that group on every - // interface, never transitions to another state, and never sends a - // report for that group. - // - // As per RFC 2710 section 5 page 10 (for MLDv1), - // - // The link-scope all-nodes address (FF02::1) is handled as a special - // case. 
The node starts in Idle Listener state for that address on - // every interface, never transitions to another state, and never sends - // a Report or Done for that address. - info.state = idleMember - return - } - - info.state = pendingMember - g.maybeSendInitialReportLocked(groupAddress, info) -} - -// maybeSendInitialReportLocked attempts to start transmission of the initial -// set of reports after newly joining a group. -// -// Host must be in pending member state. -// -// Precondition: g.protocolMU must be locked. -func (g *GenericMulticastProtocolState) maybeSendInitialReportLocked(groupAddress tcpip.Address, info *multicastGroupState) { - if info.state != pendingMember { - panic(fmt.Sprintf("host must be in pending member state to send initial reports; group = %s, state = %d", groupAddress, info.state)) - } - - // As per RFC 2236 section 3 page 5 (for IGMPv2), - // - // When a host joins a multicast group, it should immediately transmit an - // unsolicited Version 2 Membership Report for that group" ... "it is - // recommended that it be repeated". - // - // As per RFC 2710 section 4 page 6 (for MLDv1), - // - // When a node starts listening to a multicast address on an interface, - // it should immediately transmit an unsolicited Report for that address - // on that interface, in case it is the first listener on the link. To - // cover the possibility of the initial Report being lost or damaged, it - // is recommended that it be repeated once or twice after short delays - // [Unsolicited Report Interval]. - // - // TODO(gvisor.dev/issue/4901): Support a configurable number of initial - // unsolicited reports. - sent, err := g.opts.Protocol.SendReport(groupAddress) - if err == nil && sent { - info.lastToSendReport = true - g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay) - } -} - -// maybeSendDelayedReportLocked attempts to send the delayed report. 
-// -// Host must be in delaying or queued delaying member state. -// -// Precondition: g.protocolMU must be locked. -func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddress tcpip.Address, info *multicastGroupState) { - if !info.state.isDelayingMember() { - panic(fmt.Sprintf("host must be in delaying or queued delaying member state to send delayed reports; group = %s, state = %d", groupAddress, info.state)) - } - - sent, err := g.opts.Protocol.SendReport(groupAddress) - if err == nil && sent { - info.lastToSendReport = true - info.state = idleMember - } else { - info.state = queuedDelayingMember - } -} - -// maybeSendLeave attempts to send a leave message. -func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) { - if !g.opts.Protocol.Enabled() || !lastToSendReport { - return - } - - if groupAddress == g.opts.AllNodesAddress { - // As per RFC 2236 section 6 page 10 (for IGMPv2), - // - // The all-systems group (address 224.0.0.1) is handled as a special - // case. The host starts in Idle Member state for that group on every - // interface, never transitions to another state, and never sends a - // report for that group. - // - // As per RFC 2710 section 5 page 10 (for MLDv1), - // - // The link-scope all-nodes address (FF02::1) is handled as a special - // case. The node starts in Idle Listener state for that address on - // every interface, never transitions to another state, and never sends - // a Report or Done for that address. - return - } - - // Okay to ignore the error here as if packet write failed, the multicast - // routers will eventually drop our membership anyways. If the interface is - // being disabled or removed, the generic multicast protocol's state should - // be cleared eventually.
- // - // As per RFC 2236 section 3 page 5 (for IGMPv2), - // - // When a router receives a Report, it adds the group being reported to - // the list of multicast group memberships on the network on which it - // received the Report and sets the timer for the membership to the - // [Group Membership Interval]. Repeated Reports refresh the timer. If - // no Reports are received for a particular group before this timer has - // expired, the router assumes that the group has no local members and - // that it need not forward remotely-originated multicasts for that - // group onto the attached network. - // - // As per RFC 2710 section 4 page 5 (for MLDv1), - // - // When a router receives a Report from a link, if the reported address - // is not already present in the router's list of multicast address - // having listeners on that link, the reported address is added to the - // list, its timer is set to [Multicast Listener Interval], and its - // appearance is made known to the router's multicast routing component. - // If a Report is received for a multicast address that is already - // present in the router's list, the timer for that address is reset to - // [Multicast Listener Interval]. If an address's timer expires, it is - // assumed that there are no longer any listeners for that address - // present on the link, so it is deleted from the list and its - // disappearance is made known to the multicast routing component. - // - // The requirement to send a leave message is also optional (it MAY be - // skipped): - // - // As per RFC 2236 section 6 page 8 (for IGMPv2), - // - // "send leave" for the group on the interface. If the interface - // state says the Querier is running IGMPv1, this action SHOULD be - // skipped. If the flag saying we were the last host to report is - // cleared, this action MAY be skipped. The Leave Message is sent to - // the ALL-ROUTERS group (224.0.0.2). 
- // - // As per RFC 2710 section 5 page 8 (for MLDv1), - // - // "send done" for the address on the interface. If the flag saying - // we were the last node to report is cleared, this action MAY be - // skipped. The Done message is sent to the link-scope all-routers - // address (FF02::2). - _ = g.opts.Protocol.SendLeave(groupAddress) -} - -// transitionToNonMemberLocked transitions the given multicast group to the -// non-member/listener state. -// -// Precondition: g.protocolMU must be locked. -func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { - if info.state == nonMember { - return - } - - info.cancelDelayedReportJob() - g.maybeSendLeave(groupAddress, info.lastToSendReport) - info.lastToSendReport = false - info.state = nonMember -} - -// setDelayTimerForAddressRLocked sets the timer to send a delayed report. -// -// Precondition: g.protocolMU MUST be read locked. -func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) { - if info.state == nonMember { - return - } - - if groupAddress == g.opts.AllNodesAddress { - // As per RFC 2236 section 6 page 10 (for IGMPv2), - // - // The all-systems group (address 224.0.0.1) is handled as a special - // case. The host starts in Idle Member state for that group on every - // interface, never transitions to another state, and never sends a - // report for that group. - // - // As per RFC 2710 section 5 page 10 (for MLDv1), - // - // The link-scope all-nodes address (FF02::1) is handled as a special - // case. The node starts in Idle Listener state for that address on - // every interface, never transitions to another state, and never sends - // a Report or Done for that address.
- return - } - - // As per RFC 2236 section 3 page 3 (for IGMPv2), - // - // If a timer for the group is already running, it is reset to the random - // value only if the requested Max Response Time is less than the remaining - // value of the running timer. - // - // As per RFC 2710 section 4 page 5 (for MLDv1), - // - // If a timer for any address is already running, it is reset to the new - // random value only if the requested Maximum Response Delay is less than - // the remaining value of the running timer. - now := time.Unix(0 /* seconds */, g.opts.Clock.NowNanoseconds()) - if info.state == delayingMember { - if info.delayedReportJobFiresAt.IsZero() { - panic(fmt.Sprintf("delayed report unscheduled while in the delaying member state; group = %s", groupAddress)) - } - - if info.delayedReportJobFiresAt.Sub(now) <= maxResponseTime { - // The timer is scheduled to fire before the maximum response time so we - // leave our timer as is. - return - } - } - - info.state = delayingMember - info.cancelDelayedReportJob() - maxResponseTime = g.calculateDelayTimerDuration(maxResponseTime) - info.delayedReportJob.Schedule(maxResponseTime) - info.delayedReportJobFiresAt = now.Add(maxResponseTime) -} - -// calculateDelayTimerDuration returns a random duration in the range -// [0, maxRespTime). -func (g *GenericMulticastProtocolState) calculateDelayTimerDuration(maxRespTime time.Duration) time.Duration { - // As per RFC 2236 section 3 page 3 (for IGMPv2), - // - // When a host receives a Group-Specific Query, it sets a delay timer to a - // random value selected from the range (0, Max Response Time]... - // - // As per RFC 2710 section 4 page 6 (for MLDv1), - // - // When a node receives a Multicast-Address-Specific Query, if it is - // listening to the queried Multicast Address on the interface from - // which the Query was received, it sets a delay timer for that address - // to a random value selected from the range [0, Maximum Response Delay], - // as above.
- if maxRespTime == 0 { - return 0 - } - return time.Duration(g.opts.Rand.Int63n(int64(maxRespTime))) -} diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go deleted file mode 100644 index 60eaea37e..000000000 --- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go +++ /dev/null @@ -1,812 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip_test - -import ( - "math/rand" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/network/ip" -) - -const ( - addr1 = tcpip.Address("\x01") - addr2 = tcpip.Address("\x02") - addr3 = tcpip.Address("\x03") - addr4 = tcpip.Address("\x04") - - maxUnsolicitedReportDelay = time.Second -) - -var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil) - -type mockMulticastGroupProtocolProtectedFields struct { - sync.RWMutex - - genericMulticastGroup ip.GenericMulticastProtocolState - sendReportGroupAddrCount map[tcpip.Address]int - sendLeaveGroupAddrCount map[tcpip.Address]int - makeQueuePackets bool - disabled bool -} - -type mockMulticastGroupProtocol struct { - t *testing.T - - mu mockMulticastGroupProtocolProtectedFields -} - -func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions) { - m.mu.Lock() - defer m.mu.Unlock() 
- m.initLocked() - opts.Protocol = m - m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts) -} - -func (m *mockMulticastGroupProtocol) initLocked() { - m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int) - m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) -} - -func (m *mockMulticastGroupProtocol) setEnabled(v bool) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.disabled = !v -} - -func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.makeQueuePackets = v -} - -func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.JoinGroupLocked(addr) -} - -func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool { - m.mu.Lock() - defer m.mu.Unlock() - return m.mu.genericMulticastGroup.LeaveGroupLocked(addr) -} - -func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.HandleReportLocked(addr) -} - -func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime) -} - -func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool { - m.mu.RLock() - defer m.mu.RUnlock() - return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr) -} - -func (m *mockMulticastGroupProtocol) makeAllNonMember() { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.MakeAllNonMemberLocked() -} - -func (m *mockMulticastGroupProtocol) initializeGroups() { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.InitializeGroupsLocked() -} - -func (m *mockMulticastGroupProtocol) sendQueuedReports() { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.SendQueuedReportsLocked() -} - -// Enabled implements ip.MulticastGroupProtocol. -// -// Precondition: m.mu must be read locked. 
-func (m *mockMulticastGroupProtocol) Enabled() bool { - if m.mu.TryLock() { - m.mu.Unlock() - m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled") - } - - return !m.mu.disabled -} - -// SendReport implements ip.MulticastGroupProtocol. -// -// Precondition: m.mu must be locked. -func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) { - if m.mu.TryLock() { - m.mu.Unlock() - m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) - } - if m.mu.TryRLock() { - m.mu.RUnlock() - m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) - } - - m.mu.sendReportGroupAddrCount[groupAddress]++ - return !m.mu.makeQueuePackets, nil -} - -// SendLeave implements ip.MulticastGroupProtocol. -// -// Precondition: m.mu must be locked. 
-func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error { - if m.mu.TryLock() { - m.mu.Unlock() - m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) - } - if m.mu.TryRLock() { - m.mu.RUnlock() - m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) - } - - m.mu.sendLeaveGroupAddrCount[groupAddress]++ - return nil -} - -func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { - m.mu.Lock() - defer m.mu.Unlock() - - sendReportGroupAddrCount := make(map[tcpip.Address]int) - for _, a := range sendReportGroupAddresses { - sendReportGroupAddrCount[a] = 1 - } - - sendLeaveGroupAddrCount := make(map[tcpip.Address]int) - for _, a := range sendLeaveGroupAddresses { - sendLeaveGroupAddrCount[a] = 1 - } - - diff := cmp.Diff( - &mockMulticastGroupProtocol{ - mu: mockMulticastGroupProtocolProtectedFields{ - sendReportGroupAddrCount: sendReportGroupAddrCount, - sendLeaveGroupAddrCount: sendLeaveGroupAddrCount, - }, - }, - m, - cmp.AllowUnexported(mockMulticastGroupProtocol{}), - cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}), - // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t - cmp.FilterPath( - func(p cmp.Path) bool { - switch p.Last().String() { - case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup": - return true - } - return false - }, - cmp.Ignore(), - ), - ) - m.initLocked() - return diff -} - -func TestJoinGroup(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - shouldSendReports bool - }{ - { - name: "Normal group", - addr: addr1, - shouldSendReports: true, - }, - { - name: "All-nodes group", - addr: addr2, - shouldSendReports: false, - }, - } - - for _, test := range 
tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(0)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr2, - }) - - // Joining a group should send a report immediately and another after - // a random interval between 0 and the maximum unsolicited report delay. - mgp.joinGroup(test.addr) - if test.shouldSendReports { - if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Should have no more messages to send. 
- clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestLeaveGroup(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - shouldSendMessages bool - }{ - { - name: "Normal group", - addr: addr1, - shouldSendMessages: true, - }, - { - name: "All-nodes group", - addr: addr2, - shouldSendMessages: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(1)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr2, - }) - - mgp.joinGroup(test.addr) - if test.shouldSendMessages { - if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Leaving a group should send a leave report immediately and cancel any - // delayed reports. - { - - if !mgp.leaveGroup(test.addr) { - t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr) - } - } - if test.shouldSendMessages { - if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Should have no more messages to send. - // - // Generic multicast protocol timers are expected to take the job mutex. 
- clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestHandleReport(t *testing.T) { - tests := []struct { - name string - reportAddr tcpip.Address - expectReportsFor []tcpip.Address - }{ - { - name: "Unpecified empty", - reportAddr: "", - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Unpecified any", - reportAddr: "\x00", - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Specified", - reportAddr: addr1, - expectReportsFor: []tcpip.Address{addr2}, - }, - { - name: "Specified all-nodes", - reportAddr: addr3, - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Specified other", - reportAddr: addr4, - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(2)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, - }) - - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr3) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a report for a group we have a timer scheduled for should - // cancel 
our delayed report timer for the group. - mgp.handleReport(test.reportAddr) - if len(test.expectReportsFor) != 0 { - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestHandleQuery(t *testing.T) { - tests := []struct { - name string - queryAddr tcpip.Address - maxDelay time.Duration - expectQueriedReportsFor []tcpip.Address - expectDelayedReportsFor []tcpip.Address - }{ - { - name: "Unpecified empty", - queryAddr: "", - maxDelay: 0, - expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, - expectDelayedReportsFor: nil, - }, - { - name: "Unpecified any", - queryAddr: "\x00", - maxDelay: 1, - expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, - expectDelayedReportsFor: nil, - }, - { - name: "Specified", - queryAddr: addr1, - maxDelay: 2, - expectQueriedReportsFor: []tcpip.Address{addr1}, - expectDelayedReportsFor: []tcpip.Address{addr2}, - }, - { - name: "Specified all-nodes", - queryAddr: addr3, - maxDelay: 3, - expectQueriedReportsFor: nil, - expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Specified other", - queryAddr: addr4, - maxDelay: 4, - expectQueriedReportsFor: nil, - expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(3)), - Clock: clock, - 
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, - }) - - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr3) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a query should make us reschedule our delayed report timer - // to some time within the new max response delay. - mgp.handleQuery(test.queryAddr, test.maxDelay) - clock.Advance(test.maxDelay) - if diff := mgp.check(test.expectQueriedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // The groups that were not affected by the query should still send a - // report after the max unsolicited report delay. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check(test.expectDelayedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should have no more messages to send. 
- clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestJoinCount(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(4)), - Clock: clock, - MaxUnsolicitedReportDelay: time.Second, - }) - - // Set the join count to 2 for a group. - mgp.joinGroup(addr1) - if !mgp.isLocallyJoined(addr1) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - // Only the first join should trigger a report to be sent. - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr1) - if !mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - - // Group should still be considered joined after leaving once. - if !mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) - } - if !mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - // A leave report should only be sent once the join count reaches 0. - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - - // Leaving once more should actually remove us from the group. 
- if !mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) - } - if mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - - // Group should no longer be joined so we should not have anything to - // leave. - if mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1) - } - if mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should have no more messages to send. - // - // Generic multicast protocol timers are expected to take the job mutex. 
- clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} - -func TestMakeAllNonMemberAndInitialize(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(3)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, - }) - - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr3) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should send the leave reports for each but still consider them locally - // joined. - mgp.makeAllNonMember() - if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - // Generic multicast protocol timers are expected to take the job mutex. 
- clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - for _, group := range []tcpip.Address{addr1, addr2, addr3} { - if !mgp.isLocallyJoined(group) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group) - } - } - - // Should send the initial set of unsolcited reports. - mgp.initializeGroups() - if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} - -// TestGroupStateNonMember tests that groups do not send packets when in the -// non-member state, but are still considered locally joined. -func TestGroupStateNonMember(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(3)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - mgp.setEnabled(false) - - // Joining groups should not send any reports. 
- mgp.joinGroup(addr1) - if !mgp.isLocallyJoined(addr1) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if !mgp.isLocallyJoined(addr1) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a query should not send any reports. - mgp.handleQuery(addr1, time.Nanosecond) - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(time.Nanosecond) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Leaving groups should not send any leave messages. 
- if !mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr2) - } - if mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr2) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} - -func TestQueuedPackets(t *testing.T) { - clock := faketime.NewManualClock() - mgp := mockMulticastGroupProtocol{t: t} - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(4)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - - // Joining should trigger a SendReport, but mgp should report that we did not - // send the packet. - mgp.setQueuePackets(true) - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // The delayed report timer should have been cancelled since we did not send - // the initial report earlier. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Mock being able to successfully send the report. - mgp.setQueuePackets(false) - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // The delayed report (sent after the initial report) should now be sent. 
- clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should not have anything else to send (we should be idle). - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receive a query but mock being unable to send reports again. - mgp.setQueuePackets(true) - mgp.handleQuery(addr1, time.Nanosecond) - clock.Advance(time.Nanosecond) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Mock being able to send reports again - we should have a packet queued to - // send. - mgp.setQueuePackets(false) - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should not have anything else to send. - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receive a query again, but mock being unable to send reports. 
- mgp.setQueuePackets(true) - mgp.handleQuery(addr1, time.Nanosecond) - clock.Advance(time.Nanosecond) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a report should should transition us into the idle member state, - // even if we had a packet queued. We should no longer have any packets to - // send. - mgp.handleReport(addr1) - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // When we fail to send the initial set of reports, incoming reports should - // not affect a newly joined group's reports from being sent. - mgp.setQueuePackets(true) - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.handleReport(addr2) - // Attempting to send queued reports while still unable to send reports should - // not change the host state. - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - // Mock being able to successfully send the report. - mgp.setQueuePackets(false) - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - // The delayed report (sent after the initial report) should now be sent. 
- clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should not have anything else to send. - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} diff --git a/pkg/tcpip/network/ip/stats.go b/pkg/tcpip/network/ip/stats.go deleted file mode 100644 index 898f8b356..000000000 --- a/pkg/tcpip/network/ip/stats.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip - -import "gvisor.dev/gvisor/pkg/tcpip" - -// LINT.IfChange(MultiCounterIPStats) - -// MultiCounterIPStats holds IP statistics, each counter may have several -// versions. -type MultiCounterIPStats struct { - // PacketsReceived is the total number of IP packets received from the link - // layer. - PacketsReceived tcpip.MultiCounterStat - - // DisabledPacketsReceived is the total number of IP packets received from the - // link layer when the IP layer is disabled. 
- DisabledPacketsReceived tcpip.MultiCounterStat - - // InvalidDestinationAddressesReceived is the total number of IP packets - // received with an unknown or invalid destination address. - InvalidDestinationAddressesReceived tcpip.MultiCounterStat - - // InvalidSourceAddressesReceived is the total number of IP packets received - // with a source address that should never have been received on the wire. - InvalidSourceAddressesReceived tcpip.MultiCounterStat - - // PacketsDelivered is the total number of incoming IP packets that are - // successfully delivered to the transport layer. - PacketsDelivered tcpip.MultiCounterStat - - // PacketsSent is the total number of IP packets sent via WritePacket. - PacketsSent tcpip.MultiCounterStat - - // OutgoingPacketErrors is the total number of IP packets which failed to - // write to a link-layer endpoint. - OutgoingPacketErrors tcpip.MultiCounterStat - - // MalformedPacketsReceived is the total number of IP Packets that were - // dropped due to the IP packet header failing validation checks. - MalformedPacketsReceived tcpip.MultiCounterStat - - // MalformedFragmentsReceived is the total number of IP Fragments that were - // dropped due to the fragment failing validation checks. - MalformedFragmentsReceived tcpip.MultiCounterStat - - // IPTablesPreroutingDropped is the total number of IP packets dropped in the - // Prerouting chain. - IPTablesPreroutingDropped tcpip.MultiCounterStat - - // IPTablesInputDropped is the total number of IP packets dropped in the Input - // chain. - IPTablesInputDropped tcpip.MultiCounterStat - - // IPTablesOutputDropped is the total number of IP packets dropped in the - // Output chain. - IPTablesOutputDropped tcpip.MultiCounterStat - - // OptionTSReceived is the number of Timestamp options seen. - OptionTSReceived tcpip.MultiCounterStat - - // OptionRRReceived is the number of Record Route options seen. 
- OptionRRReceived tcpip.MultiCounterStat - - // OptionUnknownReceived is the number of unknown IP options seen. - OptionUnknownReceived tcpip.MultiCounterStat -} - -// Init sets internal counters to track a and b counters. -func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) { - m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived) - m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived) - m.InvalidDestinationAddressesReceived.Init(a.InvalidDestinationAddressesReceived, b.InvalidDestinationAddressesReceived) - m.InvalidSourceAddressesReceived.Init(a.InvalidSourceAddressesReceived, b.InvalidSourceAddressesReceived) - m.PacketsDelivered.Init(a.PacketsDelivered, b.PacketsDelivered) - m.PacketsSent.Init(a.PacketsSent, b.PacketsSent) - m.OutgoingPacketErrors.Init(a.OutgoingPacketErrors, b.OutgoingPacketErrors) - m.MalformedPacketsReceived.Init(a.MalformedPacketsReceived, b.MalformedPacketsReceived) - m.MalformedFragmentsReceived.Init(a.MalformedFragmentsReceived, b.MalformedFragmentsReceived) - m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped) - m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped) - m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped) - m.OptionTSReceived.Init(a.OptionTSReceived, b.OptionTSReceived) - m.OptionRRReceived.Init(a.OptionRRReceived, b.OptionRRReceived) - m.OptionUnknownReceived.Init(a.OptionUnknownReceived, b.OptionUnknownReceived) -} - -// LINT.ThenChange(:MultiCounterIPStats, ../../tcpip.go:IPStats) diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 9713c4448..4b21ee79c 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -17,9 +17,9 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/header/parse", - "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", - "//pkg/tcpip/network/ip", + 
"//pkg/tcpip/network/internal/fragmentation", + "//pkg/tcpip/network/internal/ip", "//pkg/tcpip/stack", ], ) @@ -40,8 +40,8 @@ go_test( "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/arp", + "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/raw", @@ -59,7 +59,7 @@ go_test( library = ":ipv4", deps = [ "//pkg/tcpip", - "//pkg/tcpip/network/testutil", + "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index acc126c3b..12632aceb 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -22,7 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index c5e4034ce..250e4846a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -27,8 +27,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header/parse" - "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/network/hash" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/stack" ) diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index c57d7f616..dc4db6e5f 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -34,8 +34,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" 
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" diff --git a/pkg/tcpip/network/ipv4/stats.go b/pkg/tcpip/network/ipv4/stats.go index bee72c649..5ae73fbfb 100644 --- a/pkg/tcpip/network/ipv4/stats.go +++ b/pkg/tcpip/network/ipv4/stats.go @@ -16,7 +16,7 @@ package ipv4 import ( "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go index fbbc6e69c..a637f9d50 100644 --- a/pkg/tcpip/network/ipv4/stats_test.go +++ b/pkg/tcpip/network/ipv4/stats_test.go @@ -19,7 +19,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/testutil" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" ) diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index d75b0b8de..bb9a02ed0 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -19,10 +19,9 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/header/parse", - "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", + "//pkg/tcpip/network/internal/fragmentation", "//pkg/tcpip/network/internal/ip", - "//pkg/tcpip/network/ip", "//pkg/tcpip/stack", ], ) @@ -44,7 +43,7 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/testutil", + "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 56bbf1cc3..c5c3ef882 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -30,8 +30,8 @@ import ( 
"gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header/parse" - "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/network/hash" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 1c6c37c91..7e714b50e 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -31,7 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/testutil" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index 2cc0dfebd..205e36cdd 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -21,7 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) diff --git a/pkg/tcpip/network/ipv6/stats.go b/pkg/tcpip/network/ipv6/stats.go index 0839be3cd..c2758352f 100644 --- a/pkg/tcpip/network/ipv6/stats.go +++ b/pkg/tcpip/network/ipv6/stats.go @@ -16,7 +16,7 @@ package ipv6 import ( "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD deleted file mode 100644 index bd62c4482..000000000 --- a/pkg/tcpip/network/testutil/BUILD +++ /dev/null @@ -1,23 +0,0 @@ 
-load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "testutil", - srcs = [ - "testutil.go", - "testutil_unsafe.go", - ], - visibility = [ - "//pkg/tcpip/network/arp:__pkg__", - "//pkg/tcpip/network/fragmentation:__pkg__", - "//pkg/tcpip/network/ipv4:__pkg__", - "//pkg/tcpip/network/ipv6:__pkg__", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go deleted file mode 100644 index f5fa77b65..000000000 --- a/pkg/tcpip/network/testutil/testutil.go +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package testutil defines types and functions used to test Network Layer -// functionality such as IP fragmentation. -package testutil - -import ( - "fmt" - "math/rand" - "reflect" - "strings" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// MockLinkEndpoint is an endpoint used for testing, it stores packets written -// to it and can mock errors. -type MockLinkEndpoint struct { - // WrittenPackets is where packets written to the endpoint are stored. 
- WrittenPackets []*stack.PacketBuffer - - mtu uint32 - err tcpip.Error - allowPackets int -} - -// NewMockLinkEndpoint creates a new MockLinkEndpoint. -// -// err is the error that will be returned once allowPackets packets are written -// to the endpoint. -func NewMockLinkEndpoint(mtu uint32, err tcpip.Error, allowPackets int) *MockLinkEndpoint { - return &MockLinkEndpoint{ - mtu: mtu, - err: err, - allowPackets: allowPackets, - } -} - -// MTU implements LinkEndpoint.MTU. -func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu } - -// Capabilities implements LinkEndpoint.Capabilities. -func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 } - -// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. -func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 } - -// LinkAddress implements LinkEndpoint.LinkAddress. -func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } - -// WritePacket implements LinkEndpoint.WritePacket. -func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - if ep.allowPackets == 0 { - return ep.err - } - ep.allowPackets-- - ep.WrittenPackets = append(ep.WrittenPackets, pkt) - return nil -} - -// WritePackets implements LinkEndpoint.WritePackets. -func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - var n int - - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err := ep.WritePacket(r, gso, protocol, pkt); err != nil { - return n, err - } - n++ - } - - return n, nil -} - -// Attach implements LinkEndpoint.Attach. -func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {} - -// IsAttached implements LinkEndpoint.IsAttached. -func (*MockLinkEndpoint) IsAttached() bool { return false } - -// Wait implements LinkEndpoint.Wait. 
-func (*MockLinkEndpoint) Wait() {} - -// ARPHardwareType implements LinkEndpoint.ARPHardwareType. -func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } - -// AddHeader implements LinkEndpoint.AddHeader. -func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { -} - -// MakeRandPkt generates a randomized packet. transportHeaderLength indicates -// how many random bytes will be copied in the Transport Header. -// extraHeaderReserveLength indicates how much extra space will be reserved for -// the other headers. The payload is made from Views of the sizes listed in -// viewSizes. -func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer { - var views buffer.VectorisedView - - for _, s := range viewSizes { - newView := buffer.NewView(s) - if _, err := rand.Read(newView); err != nil { - panic(fmt.Sprintf("rand.Read: %s", err)) - } - views.AppendView(newView) - } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength, - Data: views, - }) - pkt.NetworkProtocolNumber = proto - if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil { - panic(fmt.Sprintf("rand.Read: %s", err)) - } - return pkt -} - -func checkFieldCounts(ref, multi reflect.Value) error { - refTypeName := ref.Type().Name() - multiTypeName := multi.Type().Name() - refNumField := ref.NumField() - multiNumField := multi.NumField() - - if refNumField != multiNumField { - return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName) - } - - return nil -} - -func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error { - s, ok := ref.Addr().Interface().(**tcpip.StatCounter) - if 
!ok { - return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name()) - } - - // The field names are expected to match (case insensitive). - if !strings.EqualFold(refName, multiName) { - return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName) - } - - base := (*s).Value() - m.Increment() - if (*s).Value() != base+1 { - return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName) - } - - return nil -} - -// ValidateMultiCounterStats verifies that every counter stored in multi is -// correctly tracking its counterpart in the given counters. -func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error { - for _, c := range counters { - if err := checkFieldCounts(c, multi); err != nil { - return err - } - } - - for i := 0; i < multi.NumField(); i++ { - multiName := multi.Type().Field(i).Name - multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i)) - - if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok { - for _, c := range counters { - if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil { - return err - } - } - } else { - var countersNextField []reflect.Value - for _, c := range counters { - countersNextField = append(countersNextField, c.Field(i)) - } - if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil { - return err - } - } - } - - return nil -} diff --git a/pkg/tcpip/network/testutil/testutil_unsafe.go b/pkg/tcpip/network/testutil/testutil_unsafe.go deleted file mode 100644 index 5ff764800..000000000 --- a/pkg/tcpip/network/testutil/testutil_unsafe.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testutil - -import ( - "reflect" - "unsafe" -) - -// unsafeExposeUnexportedFields takes a Value and returns a version of it in -// which even unexported fields can be read and written. -func unsafeExposeUnexportedFields(a reflect.Value) reflect.Value { - return reflect.NewAt(a.Type(), unsafe.Pointer(a.UnsafeAddr())).Elem() -} -- cgit v1.2.3