diff options
Diffstat (limited to 'pkg/tcpip/network/internal')
11 files changed, 321 insertions, 2319 deletions
diff --git a/pkg/tcpip/network/internal/fragmentation/BUILD b/pkg/tcpip/network/internal/fragmentation/BUILD deleted file mode 100644 index 274f09092..000000000 --- a/pkg/tcpip/network/internal/fragmentation/BUILD +++ /dev/null @@ -1,54 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "reassembler_list", - out = "reassembler_list.go", - package = "fragmentation", - prefix = "reassembler", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*reassembler", - "Linker": "*reassembler", - }, -) - -go_library( - name = "fragmentation", - srcs = [ - "fragmentation.go", - "reassembler.go", - "reassembler_list.go", - ], - visibility = [ - "//pkg/tcpip/network/ipv4:__pkg__", - "//pkg/tcpip/network/ipv6:__pkg__", - ], - deps = [ - "//pkg/log", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "fragmentation_test", - size = "small", - srcs = [ - "fragmentation_test.go", - "reassembler_test.go", - ], - library = ":fragmentation", - deps = [ - "//pkg/tcpip/buffer", - "//pkg/tcpip/faketime", - "//pkg/tcpip/network/internal/testutil", - "//pkg/tcpip/stack", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go new file mode 100644 index 000000000..21c5774e9 --- /dev/null +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go @@ -0,0 +1,68 @@ +// automatically generated by stateify. + +package fragmentation + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (l *reassemblerList) StateTypeName() string { + return "pkg/tcpip/network/internal/fragmentation.reassemblerList" +} + +func (l *reassemblerList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *reassemblerList) beforeSave() {} + +// +checklocksignore +func (l *reassemblerList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *reassemblerList) afterLoad() {} + +// +checklocksignore +func (l *reassemblerList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *reassemblerEntry) StateTypeName() string { + return "pkg/tcpip/network/internal/fragmentation.reassemblerEntry" +} + +func (e *reassemblerEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *reassemblerEntry) beforeSave() {} + +// +checklocksignore +func (e *reassemblerEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *reassemblerEntry) afterLoad() {} + +// +checklocksignore +func (e *reassemblerEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*reassemblerList)(nil)) + state.Register((*reassemblerEntry)(nil)) +} diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go deleted file mode 100644 index dadfc28cc..000000000 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go +++ /dev/null @@ -1,648 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "errors" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// reassembleTimeout is dummy timeout used for testing, where the clock never -// advances. -const reassembleTimeout = 1 - -// vv is a helper to build VectorisedView from different strings. -func vv(size int, pieces ...string) buffer.VectorisedView { - views := make([]buffer.View, len(pieces)) - for i, p := range pieces { - views[i] = []byte(p) - } - - return buffer.NewVectorisedView(size, views) -} - -func pkt(size int, pieces ...string) *stack.PacketBuffer { - return stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv(size, pieces...), - }) -} - -type processInput struct { - id FragmentID - first uint16 - last uint16 - more bool - proto uint8 - pkt *stack.PacketBuffer -} - -type processOutput struct { - vv buffer.VectorisedView - proto uint8 - done bool -} - -var processTestCases = []struct { - comment string - in []processInput - out []processOutput -}{ - { - comment: "One ID", - in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "01", "23"), done: true}, - }, - }, - { - comment: "Next Header protocol mismatch", - in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, pkt: pkt(2, "01")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, pkt: pkt(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "01", "23"), proto: 6, done: true}, - }, - }, - { - comment: "Two IDs", - in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, - {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, pkt: pkt(2, "ab")}, - {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, pkt: pkt(2, "cd")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "ab", "cd"), done: true}, - {vv: vv(4, "01", "23"), done: true}, - }, - }, -} - -func TestFragmentationProcess(t *testing.T) { - for _, c := range processTestCases { - t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil) - firstFragmentProto := c.in[0].proto - for i, in := range c.in { - resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) - if err != nil { - t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s", - in.id, in.first, in.last, in.more, in.proto, in.pkt, err) - } - if done != c.out[i].done { - t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)", - in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done) - } - if c.out[i].done { - if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data().AsRange().ToOwnedView()); diff != "" { - t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s", - in.id, in.first, in.last, in.more, in.proto, in.pkt, diff) - } - if firstFragmentProto != proto { - t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)", - in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto) - } - if _, ok := f.reassemblers[in.id]; ok { - t.Errorf("Process(%d) did not remove buffer from reassemblers", i) - } - for n := f.rList.Front(); n != nil; n = n.Next() { - if n.id == in.id { - t.Errorf("Process(%d) did not remove buffer from rList", i) - } - } - } - } - }) - } -} - -func TestReassemblingTimeout(t *testing.T) { - const ( - reassemblyTimeout = time.Millisecond - protocol = 0xff - ) - - type fragment struct { - first uint16 - last uint16 - more bool - data string - } - - type event struct { - // name is a nickname of this event. - name string - - // clockAdvance is a duration to advance the clock. The clock advances - // before a fragment specified in the fragment field is processed. - clockAdvance time.Duration - - // fragment is a fragment to process. This can be nil if there is no - // fragment to process. - fragment *fragment - - // expectDone is true if the fragmentation instance should report the - // reassembly is done after the fragment is processd. - expectDone bool - - // memSizeAfterEvent is the expected memory size of the fragmentation - // instance after the event. - memSizeAfterEvent int - } - - memSizeOfFrags := func(frags ...*fragment) int { - var size int - for _, frag := range frags { - size += pkt(len(frag.data), frag.data).MemSize() - } - return size - } - - half1 := &fragment{first: 0, last: 0, more: true, data: "0"} - half2 := &fragment{first: 1, last: 1, more: false, data: "1"} - - tests := []struct { - name string - events []event - }{ - { - name: "half1 and half2 are reassembled successfully", - events: []event{ - { - name: "half1", - fragment: half1, - expectDone: false, - memSizeAfterEvent: memSizeOfFrags(half1), - }, - { - name: "half2", - fragment: half2, - expectDone: true, - memSizeAfterEvent: 0, - }, - }, - }, - { - name: "half1 timeout, half2 timeout", - events: []event{ - { - name: "half1", - fragment: half1, - expectDone: false, - memSizeAfterEvent: memSizeOfFrags(half1), - }, - { - name: "half1 just before reassembly timeout", - clockAdvance: reassemblyTimeout - 1, - memSizeAfterEvent: memSizeOfFrags(half1), - }, - { - name: "half1 reassembly timeout", - clockAdvance: 1, - memSizeAfterEvent: 0, - }, - { - name: "half2", - fragment: half2, - expectDone: false, - memSizeAfterEvent: memSizeOfFrags(half2), - }, - { - name: "half2 just before reassembly timeout", - clockAdvance: reassemblyTimeout - 1, - memSizeAfterEvent: memSizeOfFrags(half2), - }, - { - name: "half2 reassembly timeout", - clockAdvance: 1, - memSizeAfterEvent: 0, - }, - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock, nil) - for _, event := range test.events { - clock.Advance(event.clockAdvance) - if frag := event.fragment; frag != nil { - _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, pkt(len(frag.data), frag.data)) - if err != nil { - t.Fatalf("%s: f.Process failed: %s", event.name, err) - } - if done != event.expectDone { - t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone) - } - } - if got, want := f.memSize, event.memSizeAfterEvent; got != want { - t.Errorf("%s: got f.memSize = %d, want = %d", event.name, got, want) - } - } - }) - } -} - -func TestMemoryLimits(t *testing.T) { - lowLimit := pkt(1, "0").MemSize() - highLimit := 3 * lowLimit // Allow at most 3 such packets. - f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil) - // Send first fragment with id = 0. - if _, _, _, err := f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { - t.Fatal(err) - } - // Send first fragment with id = 1. - if _, _, _, err := f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")); err != nil { - t.Fatal(err) - } - // Send first fragment with id = 2. - if _, _, _, err := f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")); err != nil { - t.Fatal(err) - } - - // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be - // evicted. - if _, _, _, err := f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")); err != nil { - t.Fatal(err) - } - - if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { - t.Errorf("Memory limits are not respected: id=0 has not been evicted.") - } - if _, ok := f.reassemblers[FragmentID{ID: 1}]; ok { - t.Errorf("Memory limits are not respected: id=1 has not been evicted.") - } - if _, ok := f.reassemblers[FragmentID{ID: 3}]; !ok { - t.Errorf("Implementation of memory limits is wrong: id=3 is not present.") - } -} - -func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - memSize := pkt(1, "0").MemSize() - f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil) - // Send first fragment with id = 0. - if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { - t.Fatal(err) - } - // Send the same packet again. - if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { - t.Fatal(err) - } - - if got, want := f.memSize, memSize; got != want { - t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) - } -} - -func TestErrors(t *testing.T) { - tests := []struct { - name string - blockSize uint16 - first uint16 - last uint16 - more bool - data string - err error - }{ - { - name: "exact block size without more", - blockSize: 2, - first: 2, - last: 3, - more: false, - data: "01", - }, - { - name: "exact block size with more", - blockSize: 2, - first: 2, - last: 3, - more: true, - data: "01", - }, - { - name: "exact block size with more and extra data", - blockSize: 2, - first: 2, - last: 3, - more: true, - data: "012", - err: ErrInvalidArgs, - }, - { - name: "exact block size with more and too little data", - blockSize: 2, - first: 2, - last: 3, - more: true, - data: "0", - err: ErrInvalidArgs, - }, - { - name: "not exact block size with more", - blockSize: 2, - first: 2, - last: 2, - more: true, - data: "0", - err: ErrInvalidArgs, - }, - { - name: "not exact block size without more", - blockSize: 2, - first: 2, - last: 2, - more: false, - data: "0", - }, - { - name: "first not a multiple of block size", - blockSize: 2, - first: 3, - last: 4, - more: true, - data: "01", - err: ErrInvalidArgs, - }, - { - name: "first more than last", - blockSize: 2, - first: 4, - last: 3, - more: true, - data: "01", - err: ErrInvalidArgs, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, nil) - _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, pkt(len(test.data), test.data)) - if !errors.Is(err, test.err) { - t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) - } - if done { - t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, true, _), want = (_, _, false, _)", test.first, test.last, test.more, test.data) - } - }) - } -} - -type fragmentInfo struct { - remaining int - copied int - offset int - more bool -} - -func TestPacketFragmenter(t *testing.T) { - const ( - reserve = 60 - proto = 0 - ) - - tests := []struct { - name string - fragmentPayloadLen uint32 - transportHeaderLen int - payloadSize int - wantFragments []fragmentInfo - }{ - { - name: "Packet exactly fits in MTU", - fragmentPayloadLen: 1280, - transportHeaderLen: 0, - payloadSize: 1280, - wantFragments: []fragmentInfo{ - {remaining: 0, copied: 1280, offset: 0, more: false}, - }, - }, - { - name: "Packet exactly does not fit in MTU", - fragmentPayloadLen: 1000, - transportHeaderLen: 0, - payloadSize: 1001, - wantFragments: []fragmentInfo{ - {remaining: 1, copied: 1000, offset: 0, more: true}, - {remaining: 0, copied: 1, offset: 1000, more: false}, - }, - }, - { - name: "Packet has a transport header", - fragmentPayloadLen: 560, - transportHeaderLen: 40, - payloadSize: 560, - wantFragments: []fragmentInfo{ - {remaining: 1, copied: 560, offset: 0, more: true}, - {remaining: 0, copied: 40, offset: 560, more: false}, - }, - }, - { - name: "Packet has a huge transport header", - fragmentPayloadLen: 500, - transportHeaderLen: 1300, - payloadSize: 500, - wantFragments: []fragmentInfo{ - {remaining: 3, copied: 500, offset: 0, more: true}, - {remaining: 2, copied: 500, offset: 500, more: true}, - {remaining: 1, copied: 500, offset: 1000, more: true}, - {remaining: 0, copied: 300, offset: 1500, more: false}, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto) - originalPayload := stack.PayloadSince(pkt.TransportHeader()) - var reassembledPayload buffer.VectorisedView - pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve) - for i := 0; ; i++ { - fragPkt, offset, copied, more := pf.BuildNextFragment() - wantFragment := test.wantFragments[i] - if got := pf.RemainingFragmentCount(); got != wantFragment.remaining { - t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining) - } - if copied != wantFragment.copied { - t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied) - } - if offset != wantFragment.offset { - t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset) - } - if more != wantFragment.more { - t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more) - } - if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen { - t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen) - } - if got := fragPkt.AvailableHeaderBytes(); got != reserve { - t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve) - } - if got := fragPkt.TransportHeader().View().Size(); got != 0 { - t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got) - } - reassembledPayload.AppendViews(fragPkt.Data().Views()) - if !more { - if i != len(test.wantFragments)-1 { - t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1) - } - break - } - } - if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload); diff != "" { - t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) - } - }) - } -} - -type testTimeoutHandler struct { - pkt *stack.PacketBuffer -} - -func (h *testTimeoutHandler) OnReassemblyTimeout(pkt *stack.PacketBuffer) { - h.pkt = pkt -} - -func TestTimeoutHandler(t *testing.T) { - const ( - proto = 99 - ) - - pk1 := pkt(1, "1") - pk2 := pkt(1, "2") - - type processParam struct { - first uint16 - last uint16 - more bool - pkt *stack.PacketBuffer - } - - tests := []struct { - name string - params []processParam - wantError bool - wantPkt *stack.PacketBuffer - }{ - { - name: "onTimeout runs", - params: []processParam{ - { - first: 0, - last: 0, - more: true, - pkt: pk1, - }, - }, - wantError: false, - wantPkt: pk1, - }, - { - name: "no first fragment", - params: []processParam{ - { - first: 1, - last: 1, - more: true, - pkt: pk1, - }, - }, - wantError: false, - wantPkt: nil, - }, - { - name: "second pkt is ignored", - params: []processParam{ - { - first: 0, - last: 0, - more: true, - pkt: pk1, - }, - { - first: 0, - last: 0, - more: true, - pkt: pk2, - }, - }, - wantError: false, - wantPkt: pk1, - }, - { - name: "invalid args - first is greater than last", - params: []processParam{ - { - first: 1, - last: 0, - more: true, - pkt: pk1, - }, - }, - wantError: true, - wantPkt: nil, - }, - } - - id := FragmentID{ID: 0} - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - handler := &testTimeoutHandler{pkt: nil} - - f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, handler) - - for _, p := range test.params { - if _, _, _, err := f.Process(id, p.first, p.last, p.more, proto, p.pkt); err != nil && !test.wantError { - t.Errorf("f.Process error = %s", err) - } - } - if !test.wantError { - r, ok := f.reassemblers[id] - if !ok { - t.Fatal("Reassembler not found") - } - f.release(r, true) - } - switch { - case handler.pkt != nil && test.wantPkt == nil: - t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data().AsRange().ToOwnedView()) - case handler.pkt == nil && test.wantPkt != nil: - t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data().AsRange().ToOwnedView()) - case handler.pkt != nil && test.wantPkt != nil: - if diff := cmp.Diff(test.wantPkt.Data().AsRange().ToOwnedView(), handler.pkt.Data().AsRange().ToOwnedView()); diff != "" { - t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff) - } - } - }) - } -} diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_list.go b/pkg/tcpip/network/internal/fragmentation/reassembler_list.go new file mode 100644 index 000000000..673bb11b0 --- /dev/null +++ b/pkg/tcpip/network/internal/fragmentation/reassembler_list.go @@ -0,0 +1,221 @@ +package fragmentation + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type reassemblerElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (reassemblerElementMapper) linkerFor(elem *reassembler) *reassembler { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type reassemblerList struct { + head *reassembler + tail *reassembler +} + +// Reset resets list l to the empty state. +func (l *reassemblerList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *reassemblerList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *reassemblerList) Front() *reassembler { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *reassemblerList) Back() *reassembler { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *reassemblerList) Len() (count int) { + for e := l.Front(); e != nil; e = (reassemblerElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *reassemblerList) PushFront(e *reassembler) { + linker := reassemblerElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + reassemblerElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *reassemblerList) PushBack(e *reassembler) { + linker := reassemblerElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + reassemblerElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *reassemblerList) PushBackList(m *reassemblerList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + reassemblerElementMapper{}.linkerFor(l.tail).SetNext(m.head) + reassemblerElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *reassemblerList) InsertAfter(b, e *reassembler) { + bLinker := reassemblerElementMapper{}.linkerFor(b) + eLinker := reassemblerElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + reassemblerElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *reassemblerList) InsertBefore(a, e *reassembler) { + aLinker := reassemblerElementMapper{}.linkerFor(a) + eLinker := reassemblerElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + reassemblerElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *reassemblerList) Remove(e *reassembler) { + linker := reassemblerElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + reassemblerElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + reassemblerElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type reassemblerEntry struct { + next *reassembler + prev *reassembler +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *reassemblerEntry) Next() *reassembler { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *reassemblerEntry) Prev() *reassembler { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *reassemblerEntry) SetNext(elem *reassembler) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *reassemblerEntry) SetPrev(elem *reassembler) { + e.prev = elem +} diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go deleted file mode 100644 index cfd9f00ef..000000000 --- a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go +++ /dev/null @@ -1,233 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "bytes" - "math" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type processParams struct { - first uint16 - last uint16 - more bool - pkt *stack.PacketBuffer - wantDone bool - wantError error -} - -func TestReassemblerProcess(t *testing.T) { - const proto = 99 - - v := func(size int) buffer.View { - payload := buffer.NewView(size) - for i := 1; i < size; i++ { - payload[i] = uint8(i) * 3 - } - return payload - } - - pkt := func(sizes ...int) *stack.PacketBuffer { - var vv buffer.VectorisedView - for _, size := range sizes { - vv.AppendView(v(size)) - } - return stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - } - - var tests = []struct { - name string - params []processParams - want []hole - wantPkt *stack.PacketBuffer - }{ - { - name: "No fragments", - params: nil, - want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}}, - }, - { - name: "One fragment at beginning", - params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, - want: []hole{ - {first: 0, last: 1, filled: true, final: false, pkt: pkt(2)}, - {first: 2, last: math.MaxUint16, filled: false, final: true}, - }, - }, - { - name: "One fragment in the middle", - params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, - want: []hole{ - {first: 1, last: 2, filled: true, final: false, pkt: pkt(2)}, - {first: 0, last: 0, filled: false, final: false}, - {first: 3, last: math.MaxUint16, filled: false, final: true}, - }, - }, - { - name: "One fragment at the end", - params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}}, - want: []hole{ - {first: 1, last: 2, filled: true, final: true, pkt: pkt(2)}, - {first: 0, last: 0, filled: false}, - }, - }, - { - name: "One fragment completing a packet", - params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}}, - want: []hole{ - {first: 0, last: 1, filled: true, final: true}, - }, - wantPkt: pkt(2), - }, - { - name: "Two fragments completing a packet", - params: []processParams{ - {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, - }, - want: []hole{ - {first: 0, last: 1, filled: true, final: false}, - {first: 2, last: 3, filled: true, final: true}, - }, - wantPkt: pkt(2, 2), - }, - { - name: "Two fragments completing a packet with a duplicate", - params: []processParams{ - {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, - }, - want: []hole{ - {first: 0, last: 1, filled: true, final: false}, - {first: 2, last: 3, filled: true, final: true}, - }, - wantPkt: pkt(2, 2), - }, - { - name: "Two fragments completing a packet with a partial duplicate", - params: []processParams{ - {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil}, - {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, - }, - want: []hole{ - {first: 0, last: 3, filled: true, final: false}, - {first: 4, last: 5, filled: true, final: true}, - }, - wantPkt: pkt(4, 2), - }, - { - name: "Two overlapping fragments", - params: []processParams{ - {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil}, - {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap}, - }, - want: []hole{ - {first: 0, last: 10, filled: true, final: false, pkt: pkt(11)}, - {first: 11, last: math.MaxUint16, filled: false, final: true}, - }, - }, - { - name: "Two final fragments with different ends", - params: []processParams{ - {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, - {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict}, - }, - want: []hole{ - {first: 10, last: 14, filled: true, final: true, pkt: pkt(5)}, - {first: 0, last: 9, filled: false, final: false}, - }, - }, - { - name: "Two final fragments - duplicate", - params: []processParams{ - {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, - {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, - }, - want: []hole{ - {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)}, - {first: 0, last: 4, filled: false, final: false}, - }, - }, - { - name: "Two final fragments - duplicate, with different ends", - params: []processParams{ - {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, - {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict}, - }, - want: []hole{ - {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)}, - {first: 0, last: 4, filled: false, final: false}, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - r := newReassembler(FragmentID{}, &faketime.NullClock{}) - var resPkt *stack.PacketBuffer - var isDone bool - for _, param := range test.params { - pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) - if done != param.wantDone || err != param.wantError { - t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError) - } - if done { - resPkt = pkt - isDone = true - } - } - - ignorePkt := func(a, b *stack.PacketBuffer) bool { return true } - cmpPktData := func(a, b *stack.PacketBuffer) bool { - if a == nil || b == nil { - return a == b - } - return bytes.Equal(a.Data().AsRange().ToOwnedView(), b.Data().AsRange().ToOwnedView()) - } - - if isDone { - if diff := cmp.Diff( - test.want, r.holes, - cmp.AllowUnexported(hole{}), - // Do not compare pkt in hole. Data will be altered. - cmp.Comparer(ignorePkt), - ); diff != "" { - t.Errorf("r.holes mismatch (-want +got):\n%s", diff) - } - if diff := cmp.Diff(test.wantPkt, resPkt, cmp.Comparer(cmpPktData)); diff != "" { - t.Errorf("Reassembled pkt mismatch (-want +got):\n%s", diff) - } - } else { - if diff := cmp.Diff( - test.want, r.holes, - cmp.AllowUnexported(hole{}), - cmp.Comparer(cmpPktData), - ); diff != "" { - t.Errorf("r.holes mismatch (-want +got):\n%s", diff) - } - } - }) - } -} diff --git a/pkg/tcpip/network/internal/ip/BUILD b/pkg/tcpip/network/internal/ip/BUILD deleted file mode 100644 index fd944ce99..000000000 --- a/pkg/tcpip/network/internal/ip/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ip", - srcs = [ - "duplicate_address_detection.go", - "errors.go", - "generic_multicast_protocol.go", - "stats.go", - ], - visibility = [ - "//pkg/tcpip/network/arp:__pkg__", - "//pkg/tcpip/network/ipv4:__pkg__", - "//pkg/tcpip/network/ipv6:__pkg__", - ], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ip_x_test", - size = "small", - srcs = [ - "duplicate_address_detection_test.go", - "generic_multicast_protocol_test.go", - ], - deps = [ - ":ip", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/faketime", - "//pkg/tcpip/stack", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go deleted file mode 100644 index 24687cf06..000000000 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go +++ /dev/null @@ -1,381 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip_test - -import ( - "bytes" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type mockDADProtocol struct { - t *testing.T - - mu struct { - sync.Mutex - - dad ip.DAD - sentNonces map[tcpip.Address][][]byte - } -} - -func (m *mockDADProtocol) init(t *testing.T, c stack.DADConfigurations, opts ip.DADOptions) { - m.mu.Lock() - defer m.mu.Unlock() - - m.t = t - opts.Protocol = m - m.mu.dad.Init(&m.mu, c, opts) - m.initLocked() -} - -func (m *mockDADProtocol) initLocked() { - m.mu.sentNonces = make(map[tcpip.Address][][]byte) -} - -func (m *mockDADProtocol) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.sentNonces[addr] = append(m.mu.sentNonces[addr], nonce) - return nil -} - -func (m *mockDADProtocol) check(addrs []tcpip.Address) string { - sentNonces := make(map[tcpip.Address][][]byte) - for _, a := range addrs { - sentNonces[a] = append(sentNonces[a], nil) - } - - return m.checkWithNonce(sentNonces) -} - -func (m *mockDADProtocol) checkWithNonce(expectedSentNonces map[tcpip.Address][][]byte) string { - m.mu.Lock() - defer m.mu.Unlock() - - diff := cmp.Diff(expectedSentNonces, m.mu.sentNonces) - m.initLocked() - return diff -} - -func (m *mockDADProtocol) checkDuplicateAddress(addr tcpip.Address, h stack.DADCompletionHandler) stack.DADCheckAddressDisposition { - m.mu.Lock() - defer m.mu.Unlock() - return m.mu.dad.CheckDuplicateAddressLocked(addr, h) -} - -func (m *mockDADProtocol) stop(addr tcpip.Address, reason stack.DADResult) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.dad.StopLocked(addr, reason) -} - -func (m *mockDADProtocol) extendIfNonceEqual(addr tcpip.Address, nonce []byte) ip.ExtendIfNonceEqualLockedDisposition { - m.mu.Lock() - defer m.mu.Unlock() - return m.mu.dad.ExtendIfNonceEqualLocked(addr, nonce) -} - -func (m *mockDADProtocol) setConfigs(c stack.DADConfigurations) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.dad.SetConfigsLocked(c) -} - -const ( - addr1 = tcpip.Address("\x01") - addr2 = tcpip.Address("\x02") - addr3 = tcpip.Address("\x03") - addr4 = tcpip.Address("\x04") -) - -type dadResult struct { - Addr tcpip.Address - R stack.DADResult -} - -func handler(ch chan<- dadResult, a tcpip.Address) func(stack.DADResult) { - return func(r stack.DADResult) { - ch <- dadResult{Addr: a, R: r} - } -} - -func TestDADCheckDuplicateAddress(t *testing.T) { - var dad mockDADProtocol - clock := faketime.NewManualClock() - dad.init(t, stack.DADConfigurations{}, ip.DADOptions{ - Clock: clock, - }) - - ch := make(chan dadResult, 2) - - // DAD should initially be disabled. - if res := dad.checkDuplicateAddress(addr1, handler(nil, "")); res != stack.DADDisabled { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADDisabled) - } - // Wait for any initially fired timers to complete. - clock.RunImmediatelyScheduledJobs() - if diff := dad.check(nil); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - - // Enable and request DAD. - dadConfigs1 := stack.DADConfigurations{ - DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, - } - dad.setConfigs(dadConfigs1) - if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) - } - clock.RunImmediatelyScheduledJobs() - if diff := dad.check([]tcpip.Address{addr1}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - // The second request for DAD on the same address should use the original - // request since it has not completed yet. - if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADAlreadyRunning { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADAlreadyRunning) - } - clock.RunImmediatelyScheduledJobs() - if diff := dad.check(nil); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - - dadConfigs2 := stack.DADConfigurations{ - DupAddrDetectTransmits: 2, - RetransmitTimer: time.Second, - } - dad.setConfigs(dadConfigs2) - // A new address should start a new DAD process. - if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) - } - clock.RunImmediatelyScheduledJobs() - if diff := dad.check([]tcpip.Address{addr2}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - - // Make sure DAD for addr1 only resolves after the expected timeout. - const delta = time.Nanosecond - dadConfig1Duration := time.Duration(dadConfigs1.DupAddrDetectTransmits) * dadConfigs1.RetransmitTimer - clock.Advance(dadConfig1Duration - delta) - select { - case r := <-ch: - t.Fatalf("unexpectedly got a DAD result before the expected timeout of %s; r = %#v", dadConfig1Duration, r) - default: - } - clock.Advance(delta) - for i := 0; i < 2; i++ { - if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" { - t.Errorf("(i=%d) dad result mismatch (-want +got):\n%s", i, diff) - } - } - - // Make sure DAD for addr2 only resolves after the expected timeout. - dadConfig2Duration := time.Duration(dadConfigs2.DupAddrDetectTransmits) * dadConfigs2.RetransmitTimer - clock.Advance(dadConfig2Duration - dadConfig1Duration - delta) - select { - case r := <-ch: - t.Fatalf("unexpectedly got a DAD result before the expected timeout of %s; r = %#v", dadConfig2Duration, r) - default: - } - clock.Advance(delta) - if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADSucceeded{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - // Should be able to restart DAD for addr2 after it resolved. - if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) - } - clock.RunImmediatelyScheduledJobs() - if diff := dad.check([]tcpip.Address{addr2, addr2}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - clock.Advance(dadConfig2Duration) - if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADSucceeded{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - // Should not have anymore results. - select { - case r := <-ch: - t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r) - default: - } -} - -func TestDADStop(t *testing.T) { - var dad mockDADProtocol - clock := faketime.NewManualClock() - dadConfigs := stack.DADConfigurations{ - DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, - } - dad.init(t, dadConfigs, ip.DADOptions{ - Clock: clock, - }) - - ch := make(chan dadResult, 1) - - if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) - } - if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) - } - if res := dad.checkDuplicateAddress(addr3, handler(ch, addr3)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) - } - clock.RunImmediatelyScheduledJobs() - if diff := dad.check([]tcpip.Address{addr1, addr2, addr3}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - - dad.stop(addr1, &stack.DADAborted{}) - if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADAborted{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - dad.stop(addr2, &stack.DADDupAddrDetected{}) - if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADDupAddrDetected{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - dadResolutionDuration := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer - clock.Advance(dadResolutionDuration) - if diff := cmp.Diff(dadResult{Addr: addr3, R: &stack.DADSucceeded{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - // Should be able to restart DAD for an address we stopped DAD on. - if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) - } - clock.RunImmediatelyScheduledJobs() - if diff := dad.check([]tcpip.Address{addr1}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - clock.Advance(dadResolutionDuration) - if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - // Should not have anymore updates. - select { - case r := <-ch: - t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r) - default: - } -} - -func TestNonce(t *testing.T) { - const ( - nonceSize = 2 - - extendRequestAttempts = 2 - - dupAddrDetectTransmits = 2 - extendTransmits = 5 - ) - - var secureRNGBytes [nonceSize * (dupAddrDetectTransmits + extendTransmits)]byte - for i := range secureRNGBytes { - secureRNGBytes[i] = byte(i) - } - - tests := []struct { - name string - mockedReceivedNonce []byte - expectedResults [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition - expectedTransmits int - }{ - { - name: "not matching", - mockedReceivedNonce: []byte{0, 0}, - expectedResults: [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition{ip.NonceNotEqual, ip.NonceNotEqual}, - expectedTransmits: dupAddrDetectTransmits, - }, - { - name: "matching nonce", - mockedReceivedNonce: secureRNGBytes[:nonceSize], - expectedResults: [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition{ip.Extended, ip.AlreadyExtended}, - expectedTransmits: dupAddrDetectTransmits + extendTransmits, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var dad mockDADProtocol - clock := faketime.NewManualClock() - dadConfigs := stack.DADConfigurations{ - DupAddrDetectTransmits: dupAddrDetectTransmits, - RetransmitTimer: time.Second, - } - - var secureRNG bytes.Reader - secureRNG.Reset(secureRNGBytes[:]) - dad.init(t, dadConfigs, ip.DADOptions{ - Clock: clock, - SecureRNG: &secureRNG, - NonceSize: nonceSize, - ExtendDADTransmits: extendTransmits, - }) - - ch := make(chan dadResult, 1) - if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) - } - - clock.RunImmediatelyScheduledJobs() - for i, want := range test.expectedResults { - if got := dad.extendIfNonceEqual(addr1, test.mockedReceivedNonce); got != want { - t.Errorf("(i=%d) got dad.extendIfNonceEqual(%s, _) = %d, want = %d", i, addr1, got, want) - } - } - - for i := 0; i < test.expectedTransmits; i++ { - if diff := dad.checkWithNonce(map[tcpip.Address][][]byte{ - addr1: { - secureRNGBytes[nonceSize*i:][:nonceSize], - }, - }); diff != "" { - t.Errorf("(i=%d) dad check mismatch (-want +got):\n%s", i, diff) - } - - clock.Advance(dadConfigs.RetransmitTimer) - } - - if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - // Should not have anymore updates. - select { - case r := <-ch: - t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r) - default: - } - }) - } -} diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go deleted file mode 100644 index 1261ad414..000000000 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go +++ /dev/null @@ -1,808 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip_test - -import ( - "math/rand" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" -) - -const maxUnsolicitedReportDelay = time.Second - -var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil) - -type mockMulticastGroupProtocolProtectedFields struct { - sync.RWMutex - - genericMulticastGroup ip.GenericMulticastProtocolState - sendReportGroupAddrCount map[tcpip.Address]int - sendLeaveGroupAddrCount map[tcpip.Address]int - makeQueuePackets bool - disabled bool -} - -type mockMulticastGroupProtocol struct { - t *testing.T - - skipProtocolAddress tcpip.Address - - mu mockMulticastGroupProtocolProtectedFields -} - -func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions) { - m.mu.Lock() - defer m.mu.Unlock() - m.initLocked() - opts.Protocol = m - m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts) -} - -func (m *mockMulticastGroupProtocol) initLocked() { - m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int) - m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) -} - -func (m *mockMulticastGroupProtocol) setEnabled(v bool) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.disabled = !v -} - -func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.makeQueuePackets = v -} - -func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.JoinGroupLocked(addr) -} - -func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool { - m.mu.Lock() - defer m.mu.Unlock() - return m.mu.genericMulticastGroup.LeaveGroupLocked(addr) -} - -func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.HandleReportLocked(addr) -} - -func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime) -} - -func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool { - m.mu.RLock() - defer m.mu.RUnlock() - return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr) -} - -func (m *mockMulticastGroupProtocol) makeAllNonMember() { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.MakeAllNonMemberLocked() -} - -func (m *mockMulticastGroupProtocol) initializeGroups() { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.InitializeGroupsLocked() -} - -func (m *mockMulticastGroupProtocol) sendQueuedReports() { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.SendQueuedReportsLocked() -} - -// Enabled implements ip.MulticastGroupProtocol. -// -// Precondition: m.mu must be read locked. -func (m *mockMulticastGroupProtocol) Enabled() bool { - if m.mu.TryLock() { - m.mu.Unlock() // +checklocksforce: TryLock. - m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled") - } - - return !m.mu.disabled -} - -// SendReport implements ip.MulticastGroupProtocol. -// -// Precondition: m.mu must be locked. -func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) { - if m.mu.TryLock() { - m.mu.Unlock() // +checklocksforce: TryLock. - m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) - } - if m.mu.TryRLock() { - m.mu.RUnlock() // +checklocksforce: TryLock. - m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) - } - - m.mu.sendReportGroupAddrCount[groupAddress]++ - return !m.mu.makeQueuePackets, nil -} - -// SendLeave implements ip.MulticastGroupProtocol. -// -// Precondition: m.mu must be locked. -func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error { - if m.mu.TryLock() { - m.mu.Unlock() // +checklocksforce: TryLock. - m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) - } - if m.mu.TryRLock() { - m.mu.RUnlock() // +checklocksforce: TryLock. - m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) - } - - m.mu.sendLeaveGroupAddrCount[groupAddress]++ - return nil -} - -// ShouldPerformProtocol implements ip.MulticastGroupProtocol. -func (m *mockMulticastGroupProtocol) ShouldPerformProtocol(groupAddress tcpip.Address) bool { - return groupAddress != m.skipProtocolAddress -} - -func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { - m.mu.Lock() - defer m.mu.Unlock() - - sendReportGroupAddrCount := make(map[tcpip.Address]int) - for _, a := range sendReportGroupAddresses { - sendReportGroupAddrCount[a] = 1 - } - - sendLeaveGroupAddrCount := make(map[tcpip.Address]int) - for _, a := range sendLeaveGroupAddresses { - sendLeaveGroupAddrCount[a] = 1 - } - - diff := cmp.Diff( - &mockMulticastGroupProtocol{ - mu: mockMulticastGroupProtocolProtectedFields{ - sendReportGroupAddrCount: sendReportGroupAddrCount, - sendLeaveGroupAddrCount: sendLeaveGroupAddrCount, - }, - }, - m, - cmp.AllowUnexported(mockMulticastGroupProtocol{}), - cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}), - // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t - cmp.FilterPath( - func(p cmp.Path) bool { - switch p.Last().String() { - case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup", ".skipProtocolAddress": - return true - default: - return false - } - }, - cmp.Ignore(), - ), - ) - m.initLocked() - return diff -} - -func TestJoinGroup(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - shouldSendReports bool - }{ - { - name: "Normal group", - addr: addr1, - shouldSendReports: true, - }, - { - name: "All-nodes group", - addr: addr2, - shouldSendReports: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(0)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - - // Joining a group should send a report immediately and another after - // a random interval between 0 and the maximum unsolicited report delay. - mgp.joinGroup(test.addr) - if test.shouldSendReports { - if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestLeaveGroup(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - shouldSendMessages bool - }{ - { - name: "Normal group", - addr: addr1, - shouldSendMessages: true, - }, - { - name: "All-nodes group", - addr: addr2, - shouldSendMessages: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(1)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - - mgp.joinGroup(test.addr) - if test.shouldSendMessages { - if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Leaving a group should send a leave report immediately and cancel any - // delayed reports. - { - - if !mgp.leaveGroup(test.addr) { - t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr) - } - } - if test.shouldSendMessages { - if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Should have no more messages to send. - // - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestHandleReport(t *testing.T) { - tests := []struct { - name string - reportAddr tcpip.Address - expectReportsFor []tcpip.Address - }{ - { - name: "Unpecified empty", - reportAddr: "", - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Unpecified any", - reportAddr: "\x00", - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Specified", - reportAddr: addr1, - expectReportsFor: []tcpip.Address{addr2}, - }, - { - name: "Specified all-nodes", - reportAddr: addr3, - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Specified other", - reportAddr: addr4, - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(2)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr3) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a report for a group we have a timer scheduled for should - // cancel our delayed report timer for the group. - mgp.handleReport(test.reportAddr) - if len(test.expectReportsFor) != 0 { - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestHandleQuery(t *testing.T) { - tests := []struct { - name string - queryAddr tcpip.Address - maxDelay time.Duration - expectQueriedReportsFor []tcpip.Address - expectDelayedReportsFor []tcpip.Address - }{ - { - name: "Unpecified empty", - queryAddr: "", - maxDelay: 0, - expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, - expectDelayedReportsFor: nil, - }, - { - name: "Unpecified any", - queryAddr: "\x00", - maxDelay: 1, - expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, - expectDelayedReportsFor: nil, - }, - { - name: "Specified", - queryAddr: addr1, - maxDelay: 2, - expectQueriedReportsFor: []tcpip.Address{addr1}, - expectDelayedReportsFor: []tcpip.Address{addr2}, - }, - { - name: "Specified all-nodes", - queryAddr: addr3, - maxDelay: 3, - expectQueriedReportsFor: nil, - expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Specified other", - queryAddr: addr4, - maxDelay: 4, - expectQueriedReportsFor: nil, - expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(3)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr3) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a query should make us reschedule our delayed report timer - // to some time within the new max response delay. - mgp.handleQuery(test.queryAddr, test.maxDelay) - clock.Advance(test.maxDelay) - if diff := mgp.check(test.expectQueriedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // The groups that were not affected by the query should still send a - // report after the max unsolicited report delay. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check(test.expectDelayedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestJoinCount(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(4)), - Clock: clock, - MaxUnsolicitedReportDelay: time.Second, - }) - - // Set the join count to 2 for a group. - mgp.joinGroup(addr1) - if !mgp.isLocallyJoined(addr1) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - // Only the first join should trigger a report to be sent. - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr1) - if !mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - - // Group should still be considered joined after leaving once. - if !mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) - } - if !mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - // A leave report should only be sent once the join count reaches 0. - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - - // Leaving once more should actually remove us from the group. - if !mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) - } - if mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - - // Group should no longer be joined so we should not have anything to - // leave. - if mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1) - } - if mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should have no more messages to send. - // - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} - -func TestMakeAllNonMemberAndInitialize(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(3)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr3) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should send the leave reports for each but still consider them locally - // joined. - mgp.makeAllNonMember() - if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - for _, group := range []tcpip.Address{addr1, addr2, addr3} { - if !mgp.isLocallyJoined(group) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group) - } - } - - // Should send the initial set of unsolcited reports. - mgp.initializeGroups() - if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} - -// TestGroupStateNonMember tests that groups do not send packets when in the -// non-member state, but are still considered locally joined. -func TestGroupStateNonMember(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(3)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - mgp.setEnabled(false) - - // Joining groups should not send any reports. - mgp.joinGroup(addr1) - if !mgp.isLocallyJoined(addr1) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if !mgp.isLocallyJoined(addr1) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a query should not send any reports. - mgp.handleQuery(addr1, time.Nanosecond) - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(time.Nanosecond) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Leaving groups should not send any leave messages. - if !mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr2) - } - if mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr2) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} - -func TestQueuedPackets(t *testing.T) { - clock := faketime.NewManualClock() - mgp := mockMulticastGroupProtocol{t: t} - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(4)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - - // Joining should trigger a SendReport, but mgp should report that we did not - // send the packet. - mgp.setQueuePackets(true) - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // The delayed report timer should have been cancelled since we did not send - // the initial report earlier. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Mock being able to successfully send the report. - mgp.setQueuePackets(false) - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // The delayed report (sent after the initial report) should now be sent. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should not have anything else to send (we should be idle). - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receive a query but mock being unable to send reports again. - mgp.setQueuePackets(true) - mgp.handleQuery(addr1, time.Nanosecond) - clock.Advance(time.Nanosecond) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Mock being able to send reports again - we should have a packet queued to - // send. - mgp.setQueuePackets(false) - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should not have anything else to send. - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receive a query again, but mock being unable to send reports. - mgp.setQueuePackets(true) - mgp.handleQuery(addr1, time.Nanosecond) - clock.Advance(time.Nanosecond) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a report should should transition us into the idle member state, - // even if we had a packet queued. We should no longer have any packets to - // send. - mgp.handleReport(addr1) - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // When we fail to send the initial set of reports, incoming reports should - // not affect a newly joined group's reports from being sent. - mgp.setQueuePackets(true) - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.handleReport(addr2) - // Attempting to send queued reports while still unable to send reports should - // not change the host state. - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - // Mock being able to successfully send the report. - mgp.setQueuePackets(false) - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - // The delayed report (sent after the initial report) should now be sent. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should not have anything else to send. - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} diff --git a/pkg/tcpip/network/internal/ip/ip_state_autogen.go b/pkg/tcpip/network/internal/ip/ip_state_autogen.go new file mode 100644 index 000000000..360922bfe --- /dev/null +++ b/pkg/tcpip/network/internal/ip/ip_state_autogen.go @@ -0,0 +1,32 @@ +// automatically generated by stateify. + +package ip + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (e *ErrMessageTooLong) StateTypeName() string { + return "pkg/tcpip/network/internal/ip.ErrMessageTooLong" +} + +func (e *ErrMessageTooLong) StateFields() []string { + return []string{} +} + +func (e *ErrMessageTooLong) beforeSave() {} + +// +checklocksignore +func (e *ErrMessageTooLong) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrMessageTooLong) afterLoad() {} + +// +checklocksignore +func (e *ErrMessageTooLong) StateLoad(stateSourceObject state.Source) { +} + +func init() { + state.Register((*ErrMessageTooLong)(nil)) +} diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD deleted file mode 100644 index a180e5c75..000000000 --- a/pkg/tcpip/network/internal/testutil/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "testutil", - srcs = ["testutil.go"], - visibility = [ - "//pkg/tcpip/network/arp:__pkg__", - "//pkg/tcpip/network/internal/fragmentation:__pkg__", - "//pkg/tcpip/network/ipv4:__pkg__", - "//pkg/tcpip/network/ipv6:__pkg__", - "//pkg/tcpip/tests/integration:__pkg__", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go deleted file mode 100644 index 4d4d98caf..000000000 --- a/pkg/tcpip/network/internal/testutil/testutil.go +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package testutil defines types and functions used to test Network Layer -// functionality such as IP fragmentation. -package testutil - -import ( - "fmt" - "math/rand" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// MockLinkEndpoint is an endpoint used for testing, it stores packets written -// to it and can mock errors. -type MockLinkEndpoint struct { - // WrittenPackets is where packets written to the endpoint are stored. - WrittenPackets []*stack.PacketBuffer - - mtu uint32 - err tcpip.Error - allowPackets int -} - -// NewMockLinkEndpoint creates a new MockLinkEndpoint. -// -// err is the error that will be returned once allowPackets packets are written -// to the endpoint. -func NewMockLinkEndpoint(mtu uint32, err tcpip.Error, allowPackets int) *MockLinkEndpoint { - return &MockLinkEndpoint{ - mtu: mtu, - err: err, - allowPackets: allowPackets, - } -} - -// MTU implements LinkEndpoint.MTU. -func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu } - -// Capabilities implements LinkEndpoint.Capabilities. -func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 } - -// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. -func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 } - -// LinkAddress implements LinkEndpoint.LinkAddress. -func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } - -// WritePacket implements LinkEndpoint.WritePacket. -func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - if ep.allowPackets == 0 { - return ep.err - } - ep.allowPackets-- - ep.WrittenPackets = append(ep.WrittenPackets, pkt) - return nil -} - -// WritePackets implements LinkEndpoint.WritePackets. -func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - var n int - - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err := ep.WritePacket(r, protocol, pkt); err != nil { - return n, err - } - n++ - } - - return n, nil -} - -// Attach implements LinkEndpoint.Attach. -func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {} - -// IsAttached implements LinkEndpoint.IsAttached. -func (*MockLinkEndpoint) IsAttached() bool { return false } - -// Wait implements LinkEndpoint.Wait. -func (*MockLinkEndpoint) Wait() {} - -// ARPHardwareType implements LinkEndpoint.ARPHardwareType. -func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } - -// AddHeader implements LinkEndpoint.AddHeader. -func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { -} - -// WriteRawPacket implements stack.LinkEndpoint. -func (*MockLinkEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { - return &tcpip.ErrNotSupported{} -} - -// MakeRandPkt generates a randomized packet. transportHeaderLength indicates -// how many random bytes will be copied in the Transport Header. -// extraHeaderReserveLength indicates how much extra space will be reserved for -// the other headers. The payload is made from Views of the sizes listed in -// viewSizes. -func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer { - var views buffer.VectorisedView - - for _, s := range viewSizes { - newView := buffer.NewView(s) - if _, err := rand.Read(newView); err != nil { - panic(fmt.Sprintf("rand.Read: %s", err)) - } - views.AppendView(newView) - } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength, - Data: views, - }) - pkt.NetworkProtocolNumber = proto - if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil { - panic(fmt.Sprintf("rand.Read: %s", err)) - } - return pkt -} |