summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network/internal/fragmentation
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network/internal/fragmentation')
-rw-r--r--pkg/tcpip/network/internal/fragmentation/BUILD54
-rw-r--r--pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go68
-rw-r--r--pkg/tcpip/network/internal/fragmentation/fragmentation_test.go648
-rw-r--r--pkg/tcpip/network/internal/fragmentation/reassembler_list.go221
-rw-r--r--pkg/tcpip/network/internal/fragmentation/reassembler_test.go233
5 files changed, 289 insertions, 935 deletions
diff --git a/pkg/tcpip/network/internal/fragmentation/BUILD b/pkg/tcpip/network/internal/fragmentation/BUILD
deleted file mode 100644
index 274f09092..000000000
--- a/pkg/tcpip/network/internal/fragmentation/BUILD
+++ /dev/null
@@ -1,54 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "reassembler_list",
- out = "reassembler_list.go",
- package = "fragmentation",
- prefix = "reassembler",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*reassembler",
- "Linker": "*reassembler",
- },
-)
-
-go_library(
- name = "fragmentation",
- srcs = [
- "fragmentation.go",
- "reassembler.go",
- "reassembler_list.go",
- ],
- visibility = [
- "//pkg/tcpip/network/ipv4:__pkg__",
- "//pkg/tcpip/network/ipv6:__pkg__",
- ],
- deps = [
- "//pkg/log",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "fragmentation_test",
- size = "small",
- srcs = [
- "fragmentation_test.go",
- "reassembler_test.go",
- ],
- library = ":fragmentation",
- deps = [
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/network/internal/testutil",
- "//pkg/tcpip/stack",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go
new file mode 100644
index 000000000..21c5774e9
--- /dev/null
+++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go
@@ -0,0 +1,68 @@
+// automatically generated by stateify.
+
+package fragmentation
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (l *reassemblerList) StateTypeName() string {
+ return "pkg/tcpip/network/internal/fragmentation.reassemblerList"
+}
+
+func (l *reassemblerList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *reassemblerList) beforeSave() {}
+
+// +checklocksignore
+func (l *reassemblerList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *reassemblerList) afterLoad() {}
+
+// +checklocksignore
+func (l *reassemblerList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *reassemblerEntry) StateTypeName() string {
+ return "pkg/tcpip/network/internal/fragmentation.reassemblerEntry"
+}
+
+func (e *reassemblerEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *reassemblerEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *reassemblerEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *reassemblerEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *reassemblerEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func init() {
+ state.Register((*reassemblerList)(nil))
+ state.Register((*reassemblerEntry)(nil))
+}
diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
deleted file mode 100644
index dadfc28cc..000000000
--- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
+++ /dev/null
@@ -1,648 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "errors"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-// reassembleTimeout is dummy timeout used for testing, where the clock never
-// advances.
-const reassembleTimeout = 1
-
-// vv is a helper to build VectorisedView from different strings.
-func vv(size int, pieces ...string) buffer.VectorisedView {
- views := make([]buffer.View, len(pieces))
- for i, p := range pieces {
- views[i] = []byte(p)
- }
-
- return buffer.NewVectorisedView(size, views)
-}
-
-func pkt(size int, pieces ...string) *stack.PacketBuffer {
- return stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv(size, pieces...),
- })
-}
-
-type processInput struct {
- id FragmentID
- first uint16
- last uint16
- more bool
- proto uint8
- pkt *stack.PacketBuffer
-}
-
-type processOutput struct {
- vv buffer.VectorisedView
- proto uint8
- done bool
-}
-
-var processTestCases = []struct {
- comment string
- in []processInput
- out []processOutput
-}{
- {
- comment: "One ID",
- in: []processInput{
- {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")},
- {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")},
- },
- out: []processOutput{
- {vv: buffer.VectorisedView{}, done: false},
- {vv: vv(4, "01", "23"), done: true},
- },
- },
- {
- comment: "Next Header protocol mismatch",
- in: []processInput{
- {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, pkt: pkt(2, "01")},
- {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, pkt: pkt(2, "23")},
- },
- out: []processOutput{
- {vv: buffer.VectorisedView{}, done: false},
- {vv: vv(4, "01", "23"), proto: 6, done: true},
- },
- },
- {
- comment: "Two IDs",
- in: []processInput{
- {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")},
- {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, pkt: pkt(2, "ab")},
- {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, pkt: pkt(2, "cd")},
- {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")},
- },
- out: []processOutput{
- {vv: buffer.VectorisedView{}, done: false},
- {vv: buffer.VectorisedView{}, done: false},
- {vv: vv(4, "ab", "cd"), done: true},
- {vv: vv(4, "01", "23"), done: true},
- },
- },
-}
-
-func TestFragmentationProcess(t *testing.T) {
- for _, c := range processTestCases {
- t.Run(c.comment, func(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil)
- firstFragmentProto := c.in[0].proto
- for i, in := range c.in {
- resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt)
- if err != nil {
- t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s",
- in.id, in.first, in.last, in.more, in.proto, in.pkt, err)
- }
- if done != c.out[i].done {
- t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)",
- in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done)
- }
- if c.out[i].done {
- if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data().AsRange().ToOwnedView()); diff != "" {
- t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s",
- in.id, in.first, in.last, in.more, in.proto, in.pkt, diff)
- }
- if firstFragmentProto != proto {
- t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)",
- in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto)
- }
- if _, ok := f.reassemblers[in.id]; ok {
- t.Errorf("Process(%d) did not remove buffer from reassemblers", i)
- }
- for n := f.rList.Front(); n != nil; n = n.Next() {
- if n.id == in.id {
- t.Errorf("Process(%d) did not remove buffer from rList", i)
- }
- }
- }
- }
- })
- }
-}
-
-func TestReassemblingTimeout(t *testing.T) {
- const (
- reassemblyTimeout = time.Millisecond
- protocol = 0xff
- )
-
- type fragment struct {
- first uint16
- last uint16
- more bool
- data string
- }
-
- type event struct {
- // name is a nickname of this event.
- name string
-
- // clockAdvance is a duration to advance the clock. The clock advances
- // before a fragment specified in the fragment field is processed.
- clockAdvance time.Duration
-
- // fragment is a fragment to process. This can be nil if there is no
- // fragment to process.
- fragment *fragment
-
- // expectDone is true if the fragmentation instance should report the
- // reassembly is done after the fragment is processd.
- expectDone bool
-
- // memSizeAfterEvent is the expected memory size of the fragmentation
- // instance after the event.
- memSizeAfterEvent int
- }
-
- memSizeOfFrags := func(frags ...*fragment) int {
- var size int
- for _, frag := range frags {
- size += pkt(len(frag.data), frag.data).MemSize()
- }
- return size
- }
-
- half1 := &fragment{first: 0, last: 0, more: true, data: "0"}
- half2 := &fragment{first: 1, last: 1, more: false, data: "1"}
-
- tests := []struct {
- name string
- events []event
- }{
- {
- name: "half1 and half2 are reassembled successfully",
- events: []event{
- {
- name: "half1",
- fragment: half1,
- expectDone: false,
- memSizeAfterEvent: memSizeOfFrags(half1),
- },
- {
- name: "half2",
- fragment: half2,
- expectDone: true,
- memSizeAfterEvent: 0,
- },
- },
- },
- {
- name: "half1 timeout, half2 timeout",
- events: []event{
- {
- name: "half1",
- fragment: half1,
- expectDone: false,
- memSizeAfterEvent: memSizeOfFrags(half1),
- },
- {
- name: "half1 just before reassembly timeout",
- clockAdvance: reassemblyTimeout - 1,
- memSizeAfterEvent: memSizeOfFrags(half1),
- },
- {
- name: "half1 reassembly timeout",
- clockAdvance: 1,
- memSizeAfterEvent: 0,
- },
- {
- name: "half2",
- fragment: half2,
- expectDone: false,
- memSizeAfterEvent: memSizeOfFrags(half2),
- },
- {
- name: "half2 just before reassembly timeout",
- clockAdvance: reassemblyTimeout - 1,
- memSizeAfterEvent: memSizeOfFrags(half2),
- },
- {
- name: "half2 reassembly timeout",
- clockAdvance: 1,
- memSizeAfterEvent: 0,
- },
- },
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock, nil)
- for _, event := range test.events {
- clock.Advance(event.clockAdvance)
- if frag := event.fragment; frag != nil {
- _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, pkt(len(frag.data), frag.data))
- if err != nil {
- t.Fatalf("%s: f.Process failed: %s", event.name, err)
- }
- if done != event.expectDone {
- t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone)
- }
- }
- if got, want := f.memSize, event.memSizeAfterEvent; got != want {
- t.Errorf("%s: got f.memSize = %d, want = %d", event.name, got, want)
- }
- }
- })
- }
-}
-
-func TestMemoryLimits(t *testing.T) {
- lowLimit := pkt(1, "0").MemSize()
- highLimit := 3 * lowLimit // Allow at most 3 such packets.
- f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil)
- // Send first fragment with id = 0.
- if _, _, _, err := f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil {
- t.Fatal(err)
- }
- // Send first fragment with id = 1.
- if _, _, _, err := f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")); err != nil {
- t.Fatal(err)
- }
- // Send first fragment with id = 2.
- if _, _, _, err := f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")); err != nil {
- t.Fatal(err)
- }
-
- // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
- // evicted.
- if _, _, _, err := f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")); err != nil {
- t.Fatal(err)
- }
-
- if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok {
- t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
- }
- if _, ok := f.reassemblers[FragmentID{ID: 1}]; ok {
- t.Errorf("Memory limits are not respected: id=1 has not been evicted.")
- }
- if _, ok := f.reassemblers[FragmentID{ID: 3}]; !ok {
- t.Errorf("Implementation of memory limits is wrong: id=3 is not present.")
- }
-}
-
-func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
- memSize := pkt(1, "0").MemSize()
- f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil)
- // Send first fragment with id = 0.
- if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil {
- t.Fatal(err)
- }
- // Send the same packet again.
- if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil {
- t.Fatal(err)
- }
-
- if got, want := f.memSize, memSize; got != want {
- t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
- }
-}
-
-func TestErrors(t *testing.T) {
- tests := []struct {
- name string
- blockSize uint16
- first uint16
- last uint16
- more bool
- data string
- err error
- }{
- {
- name: "exact block size without more",
- blockSize: 2,
- first: 2,
- last: 3,
- more: false,
- data: "01",
- },
- {
- name: "exact block size with more",
- blockSize: 2,
- first: 2,
- last: 3,
- more: true,
- data: "01",
- },
- {
- name: "exact block size with more and extra data",
- blockSize: 2,
- first: 2,
- last: 3,
- more: true,
- data: "012",
- err: ErrInvalidArgs,
- },
- {
- name: "exact block size with more and too little data",
- blockSize: 2,
- first: 2,
- last: 3,
- more: true,
- data: "0",
- err: ErrInvalidArgs,
- },
- {
- name: "not exact block size with more",
- blockSize: 2,
- first: 2,
- last: 2,
- more: true,
- data: "0",
- err: ErrInvalidArgs,
- },
- {
- name: "not exact block size without more",
- blockSize: 2,
- first: 2,
- last: 2,
- more: false,
- data: "0",
- },
- {
- name: "first not a multiple of block size",
- blockSize: 2,
- first: 3,
- last: 4,
- more: true,
- data: "01",
- err: ErrInvalidArgs,
- },
- {
- name: "first more than last",
- blockSize: 2,
- first: 4,
- last: 3,
- more: true,
- data: "01",
- err: ErrInvalidArgs,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, nil)
- _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, pkt(len(test.data), test.data))
- if !errors.Is(err, test.err) {
- t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
- }
- if done {
- t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, true, _), want = (_, _, false, _)", test.first, test.last, test.more, test.data)
- }
- })
- }
-}
-
-type fragmentInfo struct {
- remaining int
- copied int
- offset int
- more bool
-}
-
-func TestPacketFragmenter(t *testing.T) {
- const (
- reserve = 60
- proto = 0
- )
-
- tests := []struct {
- name string
- fragmentPayloadLen uint32
- transportHeaderLen int
- payloadSize int
- wantFragments []fragmentInfo
- }{
- {
- name: "Packet exactly fits in MTU",
- fragmentPayloadLen: 1280,
- transportHeaderLen: 0,
- payloadSize: 1280,
- wantFragments: []fragmentInfo{
- {remaining: 0, copied: 1280, offset: 0, more: false},
- },
- },
- {
- name: "Packet exactly does not fit in MTU",
- fragmentPayloadLen: 1000,
- transportHeaderLen: 0,
- payloadSize: 1001,
- wantFragments: []fragmentInfo{
- {remaining: 1, copied: 1000, offset: 0, more: true},
- {remaining: 0, copied: 1, offset: 1000, more: false},
- },
- },
- {
- name: "Packet has a transport header",
- fragmentPayloadLen: 560,
- transportHeaderLen: 40,
- payloadSize: 560,
- wantFragments: []fragmentInfo{
- {remaining: 1, copied: 560, offset: 0, more: true},
- {remaining: 0, copied: 40, offset: 560, more: false},
- },
- },
- {
- name: "Packet has a huge transport header",
- fragmentPayloadLen: 500,
- transportHeaderLen: 1300,
- payloadSize: 500,
- wantFragments: []fragmentInfo{
- {remaining: 3, copied: 500, offset: 0, more: true},
- {remaining: 2, copied: 500, offset: 500, more: true},
- {remaining: 1, copied: 500, offset: 1000, more: true},
- {remaining: 0, copied: 300, offset: 1500, more: false},
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto)
- originalPayload := stack.PayloadSince(pkt.TransportHeader())
- var reassembledPayload buffer.VectorisedView
- pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve)
- for i := 0; ; i++ {
- fragPkt, offset, copied, more := pf.BuildNextFragment()
- wantFragment := test.wantFragments[i]
- if got := pf.RemainingFragmentCount(); got != wantFragment.remaining {
- t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining)
- }
- if copied != wantFragment.copied {
- t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied)
- }
- if offset != wantFragment.offset {
- t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset)
- }
- if more != wantFragment.more {
- t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more)
- }
- if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen {
- t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen)
- }
- if got := fragPkt.AvailableHeaderBytes(); got != reserve {
- t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve)
- }
- if got := fragPkt.TransportHeader().View().Size(); got != 0 {
- t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got)
- }
- reassembledPayload.AppendViews(fragPkt.Data().Views())
- if !more {
- if i != len(test.wantFragments)-1 {
- t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1)
- }
- break
- }
- }
- if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload); diff != "" {
- t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-type testTimeoutHandler struct {
- pkt *stack.PacketBuffer
-}
-
-func (h *testTimeoutHandler) OnReassemblyTimeout(pkt *stack.PacketBuffer) {
- h.pkt = pkt
-}
-
-func TestTimeoutHandler(t *testing.T) {
- const (
- proto = 99
- )
-
- pk1 := pkt(1, "1")
- pk2 := pkt(1, "2")
-
- type processParam struct {
- first uint16
- last uint16
- more bool
- pkt *stack.PacketBuffer
- }
-
- tests := []struct {
- name string
- params []processParam
- wantError bool
- wantPkt *stack.PacketBuffer
- }{
- {
- name: "onTimeout runs",
- params: []processParam{
- {
- first: 0,
- last: 0,
- more: true,
- pkt: pk1,
- },
- },
- wantError: false,
- wantPkt: pk1,
- },
- {
- name: "no first fragment",
- params: []processParam{
- {
- first: 1,
- last: 1,
- more: true,
- pkt: pk1,
- },
- },
- wantError: false,
- wantPkt: nil,
- },
- {
- name: "second pkt is ignored",
- params: []processParam{
- {
- first: 0,
- last: 0,
- more: true,
- pkt: pk1,
- },
- {
- first: 0,
- last: 0,
- more: true,
- pkt: pk2,
- },
- },
- wantError: false,
- wantPkt: pk1,
- },
- {
- name: "invalid args - first is greater than last",
- params: []processParam{
- {
- first: 1,
- last: 0,
- more: true,
- pkt: pk1,
- },
- },
- wantError: true,
- wantPkt: nil,
- },
- }
-
- id := FragmentID{ID: 0}
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- handler := &testTimeoutHandler{pkt: nil}
-
- f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, handler)
-
- for _, p := range test.params {
- if _, _, _, err := f.Process(id, p.first, p.last, p.more, proto, p.pkt); err != nil && !test.wantError {
- t.Errorf("f.Process error = %s", err)
- }
- }
- if !test.wantError {
- r, ok := f.reassemblers[id]
- if !ok {
- t.Fatal("Reassembler not found")
- }
- f.release(r, true)
- }
- switch {
- case handler.pkt != nil && test.wantPkt == nil:
- t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data().AsRange().ToOwnedView())
- case handler.pkt == nil && test.wantPkt != nil:
- t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data().AsRange().ToOwnedView())
- case handler.pkt != nil && test.wantPkt != nil:
- if diff := cmp.Diff(test.wantPkt.Data().AsRange().ToOwnedView(), handler.pkt.Data().AsRange().ToOwnedView()); diff != "" {
- t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff)
- }
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_list.go b/pkg/tcpip/network/internal/fragmentation/reassembler_list.go
new file mode 100644
index 000000000..673bb11b0
--- /dev/null
+++ b/pkg/tcpip/network/internal/fragmentation/reassembler_list.go
@@ -0,0 +1,221 @@
+package fragmentation
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type reassemblerElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (reassemblerElementMapper) linkerFor(elem *reassembler) *reassembler { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type reassemblerList struct {
+ head *reassembler
+ tail *reassembler
+}
+
+// Reset resets list l to the empty state.
+func (l *reassemblerList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *reassemblerList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *reassemblerList) Front() *reassembler {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *reassemblerList) Back() *reassembler {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *reassemblerList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (reassemblerElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *reassemblerList) PushFront(e *reassembler) {
+ linker := reassemblerElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ reassemblerElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *reassemblerList) PushBack(e *reassembler) {
+ linker := reassemblerElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ reassemblerElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *reassemblerList) PushBackList(m *reassemblerList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ reassemblerElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ reassemblerElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *reassemblerList) InsertAfter(b, e *reassembler) {
+ bLinker := reassemblerElementMapper{}.linkerFor(b)
+ eLinker := reassemblerElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ reassemblerElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *reassemblerList) InsertBefore(a, e *reassembler) {
+ aLinker := reassemblerElementMapper{}.linkerFor(a)
+ eLinker := reassemblerElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ reassemblerElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *reassemblerList) Remove(e *reassembler) {
+ linker := reassemblerElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ reassemblerElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ reassemblerElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type reassemblerEntry struct {
+ next *reassembler
+ prev *reassembler
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *reassemblerEntry) Next() *reassembler {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *reassemblerEntry) Prev() *reassembler {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *reassemblerEntry) SetNext(elem *reassembler) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *reassemblerEntry) SetPrev(elem *reassembler) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go
deleted file mode 100644
index cfd9f00ef..000000000
--- a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go
+++ /dev/null
@@ -1,233 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "bytes"
- "math"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-type processParams struct {
- first uint16
- last uint16
- more bool
- pkt *stack.PacketBuffer
- wantDone bool
- wantError error
-}
-
-func TestReassemblerProcess(t *testing.T) {
- const proto = 99
-
- v := func(size int) buffer.View {
- payload := buffer.NewView(size)
- for i := 1; i < size; i++ {
- payload[i] = uint8(i) * 3
- }
- return payload
- }
-
- pkt := func(sizes ...int) *stack.PacketBuffer {
- var vv buffer.VectorisedView
- for _, size := range sizes {
- vv.AppendView(v(size))
- }
- return stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
- }
-
- var tests = []struct {
- name string
- params []processParams
- want []hole
- wantPkt *stack.PacketBuffer
- }{
- {
- name: "No fragments",
- params: nil,
- want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}},
- },
- {
- name: "One fragment at beginning",
- params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
- want: []hole{
- {first: 0, last: 1, filled: true, final: false, pkt: pkt(2)},
- {first: 2, last: math.MaxUint16, filled: false, final: true},
- },
- },
- {
- name: "One fragment in the middle",
- params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
- want: []hole{
- {first: 1, last: 2, filled: true, final: false, pkt: pkt(2)},
- {first: 0, last: 0, filled: false, final: false},
- {first: 3, last: math.MaxUint16, filled: false, final: true},
- },
- },
- {
- name: "One fragment at the end",
- params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}},
- want: []hole{
- {first: 1, last: 2, filled: true, final: true, pkt: pkt(2)},
- {first: 0, last: 0, filled: false},
- },
- },
- {
- name: "One fragment completing a packet",
- params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}},
- want: []hole{
- {first: 0, last: 1, filled: true, final: true},
- },
- wantPkt: pkt(2),
- },
- {
- name: "Two fragments completing a packet",
- params: []processParams{
- {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
- {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
- },
- want: []hole{
- {first: 0, last: 1, filled: true, final: false},
- {first: 2, last: 3, filled: true, final: true},
- },
- wantPkt: pkt(2, 2),
- },
- {
- name: "Two fragments completing a packet with a duplicate",
- params: []processParams{
- {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
- {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
- {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
- },
- want: []hole{
- {first: 0, last: 1, filled: true, final: false},
- {first: 2, last: 3, filled: true, final: true},
- },
- wantPkt: pkt(2, 2),
- },
- {
- name: "Two fragments completing a packet with a partial duplicate",
- params: []processParams{
- {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil},
- {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
- {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
- },
- want: []hole{
- {first: 0, last: 3, filled: true, final: false},
- {first: 4, last: 5, filled: true, final: true},
- },
- wantPkt: pkt(4, 2),
- },
- {
- name: "Two overlapping fragments",
- params: []processParams{
- {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil},
- {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap},
- },
- want: []hole{
- {first: 0, last: 10, filled: true, final: false, pkt: pkt(11)},
- {first: 11, last: math.MaxUint16, filled: false, final: true},
- },
- },
- {
- name: "Two final fragments with different ends",
- params: []processParams{
- {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
- {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict},
- },
- want: []hole{
- {first: 10, last: 14, filled: true, final: true, pkt: pkt(5)},
- {first: 0, last: 9, filled: false, final: false},
- },
- },
- {
- name: "Two final fragments - duplicate",
- params: []processParams{
- {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil},
- {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
- },
- want: []hole{
- {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)},
- {first: 0, last: 4, filled: false, final: false},
- },
- },
- {
- name: "Two final fragments - duplicate, with different ends",
- params: []processParams{
- {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil},
- {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict},
- },
- want: []hole{
- {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)},
- {first: 0, last: 4, filled: false, final: false},
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- r := newReassembler(FragmentID{}, &faketime.NullClock{})
- var resPkt *stack.PacketBuffer
- var isDone bool
- for _, param := range test.params {
- pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt)
- if done != param.wantDone || err != param.wantError {
- t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError)
- }
- if done {
- resPkt = pkt
- isDone = true
- }
- }
-
- ignorePkt := func(a, b *stack.PacketBuffer) bool { return true }
- cmpPktData := func(a, b *stack.PacketBuffer) bool {
- if a == nil || b == nil {
- return a == b
- }
- return bytes.Equal(a.Data().AsRange().ToOwnedView(), b.Data().AsRange().ToOwnedView())
- }
-
- if isDone {
- if diff := cmp.Diff(
- test.want, r.holes,
- cmp.AllowUnexported(hole{}),
- // Do not compare pkt in hole. Data will be altered.
- cmp.Comparer(ignorePkt),
- ); diff != "" {
- t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(test.wantPkt, resPkt, cmp.Comparer(cmpPktData)); diff != "" {
- t.Errorf("Reassembled pkt mismatch (-want +got):\n%s", diff)
- }
- } else {
- if diff := cmp.Diff(
- test.want, r.holes,
- cmp.AllowUnexported(hole{}),
- cmp.Comparer(cmpPktData),
- ); diff != "" {
- t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
- }
- }
- })
- }
-}