diff options
author | gVisor bot <gvisor-bot@google.com> | 2020-04-04 01:39:59 +0000 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-04-04 01:39:59 +0000 |
commit | 078753e0fe85dbcc047aaa5607a2bbc209491672 (patch) | |
tree | 14790716c79a7318214f1dc9293b6158a47851c3 /pkg | |
parent | 818a047ab6deabe5a75a5452cdb950cc0e22d722 (diff) | |
parent | fc99a7ebf0c24b6f7b3cfd6351436373ed54548b (diff) |
Merge release-20200323.0-69-gfc99a7e (automated)
Diffstat (limited to 'pkg')
48 files changed, 728 insertions, 318 deletions
diff --git a/pkg/buffer/buffer_list.go b/pkg/buffer/buffer_list.go index e2d519538..15c7a74cc 100755 --- a/pkg/buffer/buffer_list.go +++ b/pkg/buffer/buffer_list.go @@ -52,12 +52,21 @@ func (l *bufferList) Back() *buffer { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *bufferList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *bufferList) PushFront(e *buffer) { linker := bufferElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { bufferElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *bufferList) PushBack(e *buffer) { linker := bufferElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { bufferElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *bufferList) PushBackList(m *bufferList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/ilist/interface_list.go b/pkg/ilist/interface_list.go index aeb636f52..cabf46e3f 100755 --- a/pkg/ilist/interface_list.go +++ b/pkg/ilist/interface_list.go @@ -71,12 +71,21 @@ func (l *List) Back() Element { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *List) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *List) PushFront(e Element) { linker := ElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { ElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -91,7 +100,6 @@ func (l *List) PushBack(e Element) { linker := ElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { ElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -112,7 +120,6 @@ func (l *List) PushBackList(m *List) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/refs/weak_ref_list.go b/pkg/refs/weak_ref_list.go index 1d0ae2099..90433fb28 100755 --- a/pkg/refs/weak_ref_list.go +++ b/pkg/refs/weak_ref_list.go @@ -52,12 +52,21 @@ func (l *weakRefList) Back() *WeakRef { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *weakRefList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *weakRefList) PushFront(e *WeakRef) { linker := weakRefElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { weakRefElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *weakRefList) PushBack(e *WeakRef) { linker := weakRefElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { weakRefElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *weakRefList) PushBackList(m *weakRefList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/sentry/fs/dirent_list.go b/pkg/sentry/fs/dirent_list.go index acdce100c..ecbbd7883 100755 --- a/pkg/sentry/fs/dirent_list.go +++ b/pkg/sentry/fs/dirent_list.go @@ -52,12 +52,21 @@ func (l *direntList) Back() *Dirent { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *direntList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *direntList) PushFront(e *Dirent) { linker := direntElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { direntElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *direntList) PushBack(e *Dirent) { linker := direntElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { direntElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *direntList) PushBackList(m *direntList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/sentry/fs/event_list.go b/pkg/sentry/fs/event_list.go index 0274f41a2..167fdf906 100755 --- a/pkg/sentry/fs/event_list.go +++ b/pkg/sentry/fs/event_list.go @@ -52,12 +52,21 @@ func (l *eventList) Back() *Event { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *eventList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *eventList) PushFront(e *Event) { linker := eventElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { eventElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *eventList) PushBack(e *Event) { linker := eventElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { eventElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *eventList) PushBackList(m *eventList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/sentry/fsimpl/kernfs/slot_list.go b/pkg/sentry/fsimpl/kernfs/slot_list.go index 5c8020e11..09c30bca7 100755 --- a/pkg/sentry/fsimpl/kernfs/slot_list.go +++ b/pkg/sentry/fsimpl/kernfs/slot_list.go @@ -52,12 +52,21 @@ func (l *slotList) Back() *slot { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *slotList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *slotList) PushFront(e *slot) { linker := slotElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { slotElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *slotList) PushBack(e *slot) { linker := slotElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { slotElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *slotList) PushBackList(m *slotList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/sentry/kernel/epoll/epoll_list.go b/pkg/sentry/kernel/epoll/epoll_list.go index 37f757fa8..a018f7b5c 100755 --- a/pkg/sentry/kernel/epoll/epoll_list.go +++ b/pkg/sentry/kernel/epoll/epoll_list.go @@ -52,12 +52,21 @@ func (l *pollEntryList) Back() *pollEntry { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *pollEntryList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *pollEntryList) PushFront(e *pollEntry) { linker := pollEntryElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { pollEntryElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *pollEntryList) PushBack(e *pollEntry) { linker := pollEntryElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { pollEntryElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *pollEntryList) PushBackList(m *pollEntryList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/sentry/kernel/futex/waiter_list.go b/pkg/sentry/kernel/futex/waiter_list.go index 204eededf..1b7a92b62 100755 --- a/pkg/sentry/kernel/futex/waiter_list.go +++ b/pkg/sentry/kernel/futex/waiter_list.go @@ -52,12 +52,21 @@ func (l *waiterList) Back() *Waiter { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *waiterList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *waiterList) PushFront(e *Waiter) { linker := waiterElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { waiterElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *waiterList) PushBack(e *Waiter) { linker := waiterElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { waiterElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *waiterList) PushBackList(m *waiterList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 0a448b57c..2e6f42b92 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -564,15 +564,25 @@ func (ts *TaskSet) unregisterEpollWaiters() { ts.mu.RLock() defer ts.mu.RUnlock() + + // Tasks that belong to the same process could potentially point to the + // same FDTable. So we retain a map of processed ones to avoid + // processing the same FDTable multiple times. + processed := make(map[*FDTable]struct{}) for t := range ts.Root.tids { // We can skip locking Task.mu here since the kernel is paused. - if t.fdTable != nil { - t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { - if e, ok := file.FileOperations.(*epoll.EventPoll); ok { - e.UnregisterEpollWaiters() - } - }) + if t.fdTable == nil { + continue + } + if _, ok := processed[t.fdTable]; ok { + continue } + t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { + if e, ok := file.FileOperations.(*epoll.EventPoll); ok { + e.UnregisterEpollWaiters() + } + }) + processed[t.fdTable] = struct{}{} } } diff --git a/pkg/sentry/kernel/pending_signals_list.go b/pkg/sentry/kernel/pending_signals_list.go index 8eb40ac2c..2685c631a 100755 --- a/pkg/sentry/kernel/pending_signals_list.go +++ b/pkg/sentry/kernel/pending_signals_list.go @@ -52,12 +52,21 @@ func (l *pendingSignalList) Back() *pendingSignal { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *pendingSignalList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *pendingSignalList) PushFront(e *pendingSignal) { linker := pendingSignalElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { pendingSignalElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *pendingSignalList) PushBack(e *pendingSignal) { linker := pendingSignalElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { pendingSignalElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *pendingSignalList) PushBackList(m *pendingSignalList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/sentry/kernel/process_group_list.go b/pkg/sentry/kernel/process_group_list.go index 40c1a13a4..3c5ea8aa7 100755 --- a/pkg/sentry/kernel/process_group_list.go +++ b/pkg/sentry/kernel/process_group_list.go @@ -52,12 +52,21 @@ func (l *processGroupList) Back() *ProcessGroup { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *processGroupList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *processGroupList) PushFront(e *ProcessGroup) { linker := processGroupElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { processGroupElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *processGroupList) PushBack(e *ProcessGroup) { linker := processGroupElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { processGroupElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *processGroupList) PushBackList(m *processGroupList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/sentry/kernel/semaphore/waiter_list.go b/pkg/sentry/kernel/semaphore/waiter_list.go index 27120afe3..4bfe5ff95 100755 --- a/pkg/sentry/kernel/semaphore/waiter_list.go +++ b/pkg/sentry/kernel/semaphore/waiter_list.go @@ -52,12 +52,21 @@ func (l *waiterList) Back() *waiter { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *waiterList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *waiterList) PushFront(e *waiter) { linker := waiterElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { waiterElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *waiterList) PushBack(e *waiter) { linker := waiterElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { waiterElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *waiterList) PushBackList(m *waiterList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/sentry/kernel/session_list.go b/pkg/sentry/kernel/session_list.go index 8174f413d..768482ab6 100755 --- a/pkg/sentry/kernel/session_list.go +++ b/pkg/sentry/kernel/session_list.go @@ -52,12 +52,21 @@ func (l *sessionList) Back() *Session { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *sessionList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *sessionList) PushFront(e *Session) { linker := sessionElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { sessionElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *sessionList) PushBack(e *Session) { linker := sessionElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { sessionElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *sessionList) PushBackList(m *sessionList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/sentry/kernel/socket_list.go b/pkg/sentry/kernel/socket_list.go index ac93e2365..294aa99fe 100755 --- a/pkg/sentry/kernel/socket_list.go +++ b/pkg/sentry/kernel/socket_list.go @@ -52,12 +52,21 @@ func (l *socketList) Back() *SocketEntry { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *socketList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *socketList) PushFront(e *SocketEntry) { linker := socketElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { socketElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *socketList) PushBack(e *SocketEntry) { linker := socketElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { socketElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *socketList) PushBackList(m *socketList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/sentry/kernel/task_list.go b/pkg/sentry/kernel/task_list.go index 4dfcdbf2c..e7a3a3d20 100755 --- a/pkg/sentry/kernel/task_list.go +++ b/pkg/sentry/kernel/task_list.go @@ -52,12 +52,21 @@ func (l *taskList) Back() *Task { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *taskList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *taskList) PushFront(e *Task) { linker := taskElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { taskElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *taskList) PushBack(e *Task) { linker := taskElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { taskElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *taskList) PushBackList(m *taskList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/sentry/mm/io_list.go b/pkg/sentry/mm/io_list.go index 287e4305c..ae0f19fc5 100755 --- a/pkg/sentry/mm/io_list.go +++ b/pkg/sentry/mm/io_list.go @@ -52,12 +52,21 @@ func (l *ioList) Back() *ioResult { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *ioList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *ioList) PushFront(e *ioResult) { linker := ioElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { ioElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *ioList) PushBack(e *ioResult) { linker := ioElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { ioElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *ioList) PushBackList(m *ioList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/sentry/socket/unix/transport/transport_message_list.go b/pkg/sentry/socket/unix/transport/transport_message_list.go index 9edc731b4..568cd5871 100755 --- a/pkg/sentry/socket/unix/transport/transport_message_list.go +++ b/pkg/sentry/socket/unix/transport/transport_message_list.go @@ -52,12 +52,21 @@ func (l *messageList) Back() *message { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *messageList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *messageList) PushFront(e *message) { linker := messageElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { messageElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *messageList) PushBack(e *message) { linker := messageElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { messageElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *messageList) PushBackList(m *messageList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/sentry/vfs/epoll_interest_list.go b/pkg/sentry/vfs/epoll_interest_list.go index 1bd41f400..67104a9c6 100755 --- a/pkg/sentry/vfs/epoll_interest_list.go +++ b/pkg/sentry/vfs/epoll_interest_list.go @@ -52,12 +52,21 @@ func (l *epollInterestList) Back() *epollInterest { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *epollInterestList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *epollInterestList) PushFront(e *epollInterest) { linker := epollInterestElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { epollInterestElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *epollInterestList) PushBack(e *epollInterest) { linker := epollInterestElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { epollInterestElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *epollInterestList) PushBackList(m *epollInterestList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 8d42cd066..8ec5d5d5c 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -17,6 +17,7 @@ package buffer import ( "bytes" + "io" ) // View is a slice of a buffer, with convenience methods. @@ -89,6 +90,47 @@ func (vv *VectorisedView) TrimFront(count int) { } } +// Read implements io.Reader. +func (vv *VectorisedView) Read(v View) (copied int, err error) { + count := len(v) + for count > 0 && len(vv.views) > 0 { + if count < len(vv.views[0]) { + vv.size -= count + copy(v[copied:], vv.views[0][:count]) + vv.views[0].TrimFront(count) + copied += count + return copied, nil + } + count -= len(vv.views[0]) + copy(v[copied:], vv.views[0]) + copied += len(vv.views[0]) + vv.RemoveFirst() + } + if copied == 0 { + return 0, io.EOF + } + return copied, nil +} + +// ReadToVV reads up to n bytes from vv to dstVV and removes them from vv. It +// returns the number of bytes copied. +func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int) { + for count > 0 && len(vv.views) > 0 { + if count < len(vv.views[0]) { + vv.size -= count + dstVV.AppendView(vv.views[0][:count]) + vv.views[0].TrimFront(count) + copied += count + return + } + count -= len(vv.views[0]) + dstVV.AppendView(vv.views[0]) + copied += len(vv.views[0]) + vv.RemoveFirst() + } + return copied +} + // CapLength irreversibly reduces the length of the vectorised view. func (vv *VectorisedView) CapLength(length int) { if length < 0 { @@ -116,12 +158,12 @@ func (vv *VectorisedView) CapLength(length int) { // Clone returns a clone of this VectorisedView. // If the buffer argument is large enough to contain all the Views of this VectorisedView, // the method will avoid allocations and use the buffer to store the Views of the clone. -func (vv VectorisedView) Clone(buffer []View) VectorisedView { +func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } // First returns the first view of the vectorised view. -func (vv VectorisedView) First() View { +func (vv *VectorisedView) First() View { if len(vv.views) == 0 { return nil } @@ -134,11 +176,12 @@ func (vv *VectorisedView) RemoveFirst() { return } vv.size -= len(vv.views[0]) + vv.views[0] = nil vv.views = vv.views[1:] } // Size returns the size in bytes of the entire content stored in the vectorised view. -func (vv VectorisedView) Size() int { +func (vv *VectorisedView) Size() int { return vv.size } @@ -146,7 +189,7 @@ func (vv VectorisedView) Size() int { // // If the vectorised view contains a single view, that view will be returned // directly. -func (vv VectorisedView) ToView() View { +func (vv *VectorisedView) ToView() View { if len(vv.views) == 1 { return vv.views[0] } @@ -158,7 +201,7 @@ func (vv VectorisedView) ToView() View { } // Views returns the slice containing the all views. -func (vv VectorisedView) Views() []View { +func (vv *VectorisedView) Views() []View { return vv.views } diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index a8d6653ce..b4a0ae53d 100755 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -28,7 +28,7 @@ import ( // PacketInfo holds all the information about an outbound packet. type PacketInfo struct { - Pkt stack.PacketBuffer + Pkt *stack.PacketBuffer Proto tcpip.NetworkProtocolNumber GSO *stack.GSO Route stack.Route @@ -257,7 +257,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne route := r.Clone() route.Release() p := PacketInfo{ - Pkt: pkt, + Pkt: &pkt, Proto: protocol, GSO: gso, Route: route, @@ -269,21 +269,15 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne } // WritePackets stores outbound packets into the channel. -func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { // Clone r then release its resource so we only get the relevant fields from // stack.Route without holding a reference to a NIC's endpoint. route := r.Clone() route.Release() - payloadView := pkts[0].Data.ToView() n := 0 - for _, pkt := range pkts { - off := pkt.DataOffset - size := pkt.DataSize + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { p := PacketInfo{ - Pkt: stack.PacketBuffer{ - Header: pkt.Header, - Data: buffer.NewViewFromBytes(payloadView[off : off+size]).ToVectorisedView(), - }, + Pkt: pkt, Proto: protocol, GSO: gso, Route: route, @@ -301,7 +295,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.Pac // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := PacketInfo{ - Pkt: stack.PacketBuffer{Data: vv}, + Pkt: &stack.PacketBuffer{Data: vv}, Proto: 0, GSO: nil, } diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 3b3b6909b..7198742b7 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -441,118 +441,106 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - var ethHdrBuf []byte - // hdr + data - iovLen := 2 - if e.hdrSize > 0 { - // Add ethernet header if needed. - ethHdrBuf = make([]byte, header.EthernetMinimumSize) - eth := header.Ethernet(ethHdrBuf) - ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, - Type: protocol, - } - - // Preserve the src address if it's set in the route. - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress - } else { - ethHdr.SrcAddr = e.addr - } - eth.Encode(ethHdr) - iovLen++ - } +// +// NOTE: This API uses sendmmsg to batch packets. As a result the underlying FD +// picked to write the packet out has to be the same for all packets in the +// list. In other words all packets in the batch should belong to the same +// flow. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + n := pkts.Len() - n := len(pkts) - - views := pkts[0].Data.Views() - /* - * Each boundary in views can add one more iovec. - * - * payload | | | | - * ----------------------------- - * packets | | | | | | | - * ----------------------------- - * iovecs | | | | | | | | | - */ - iovec := make([]syscall.Iovec, n*iovLen+len(views)-1) mmsgHdrs := make([]rawfile.MMsgHdr, n) + i := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + var ethHdrBuf []byte + iovLen := 0 + if e.hdrSize > 0 { + // Add ethernet header if needed. + ethHdrBuf = make([]byte, header.EthernetMinimumSize) + eth := header.Ethernet(ethHdrBuf) + ethHdr := &header.EthernetFields{ + DstAddr: r.RemoteLinkAddress, + Type: protocol, + } - iovecIdx := 0 - viewIdx := 0 - viewOff := 0 - off := 0 - nextOff := 0 - for i := range pkts { - // TODO(b/134618279): Different packets may have different data - // in the future. We should handle this. - if !viewsEqual(pkts[i].Data.Views(), views) { - panic("All packets in pkts should have the same Data.") + // Preserve the src address if it's set in the route. + if r.LocalLinkAddress != "" { + ethHdr.SrcAddr = r.LocalLinkAddress + } else { + ethHdr.SrcAddr = e.addr + } + eth.Encode(ethHdr) + iovLen++ } - prevIovecIdx := iovecIdx - mmsgHdr := &mmsgHdrs[i] - mmsgHdr.Msg.Iov = &iovec[iovecIdx] - packetSize := pkts[i].DataSize - hdr := &pkts[i].Header - - off = pkts[i].DataOffset - if off != nextOff { - // We stop in a different point last time. - size := packetSize - viewIdx = 0 - viewOff = 0 - for size > 0 { - if size >= len(views[viewIdx]) { - viewIdx++ - viewOff = 0 - size -= len(views[viewIdx]) - } else { - viewOff = size - size = 0 + var vnetHdrBuf []byte + vnetHdr := virtioNetHdr{} + if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if gso != nil { + vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) + if gso.NeedsCsum { + vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM + vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen + vnetHdr.csumOffset = gso.CsumOffset + } + if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS { + switch gso.Type { + case stack.GSOTCPv4: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 + case stack.GSOTCPv6: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 + default: + panic(fmt.Sprintf("Unknown gso type: %v", gso.Type)) + } + vnetHdr.gsoSize = gso.MSS } } + vnetHdrBuf = vnetHdrToByteSlice(&vnetHdr) + iovLen++ } - nextOff = off + packetSize + iovecs := make([]syscall.Iovec, iovLen+1+len(pkt.Data.Views())) + mmsgHdr := &mmsgHdrs[i] + mmsgHdr.Msg.Iov = &iovecs[0] + iovecIdx := 0 + if vnetHdrBuf != nil { + v := &iovecs[iovecIdx] + v.Base = &vnetHdrBuf[0] + v.Len = uint64(len(vnetHdrBuf)) + iovecIdx++ + } if ethHdrBuf != nil { - v := &iovec[iovecIdx] + v := &iovecs[iovecIdx] v.Base = ðHdrBuf[0] v.Len = uint64(len(ethHdrBuf)) iovecIdx++ } - - v := &iovec[iovecIdx] + pktSize := uint64(0) + // Encode L3 Header + v := &iovecs[iovecIdx] + hdr := &pkt.Header hdrView := hdr.View() v.Base = &hdrView[0] v.Len = uint64(len(hdrView)) + pktSize += v.Len iovecIdx++ - for packetSize > 0 { - vec := &iovec[iovecIdx] + // Now encode the Transport Payload. + pktViews := pkt.Data.Views() + for i := range pktViews { + vec := &iovecs[iovecIdx] iovecIdx++ - - v := views[viewIdx] - vec.Base = &v[viewOff] - s := len(v) - viewOff - if s <= packetSize { - viewIdx++ - viewOff = 0 - } else { - s = packetSize - viewOff += s - } - vec.Len = uint64(s) - packetSize -= s + vec.Base = &pktViews[i][0] + vec.Len = uint64(len(pktViews[i])) + pktSize += vec.Len } - - mmsgHdr.Msg.Iovlen = uint64(iovecIdx - prevIovecIdx) + mmsgHdr.Msg.Iovlen = uint64(iovecIdx) + i++ } packets := 0 for packets < n { - fd := e.fds[pkts[packets].Hash%uint32(len(e.fds))] + fd := e.fds[pkts.Front().Hash%uint32(len(e.fds))] sent, err := rawfile.NonBlockingSendMMsg(fd, mmsgHdrs) if err != nil { return packets, err diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 4039753b7..1e2255bfa 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketBuffer, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index f5973066d..a5478ce17 100755 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -87,7 +87,7 @@ func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, // WritePackets writes outbound packets to the appropriate // LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if // r.RemoteAddress has a route registered in this endpoint. -func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { endpoint, ok := m.routes[r.RemoteAddress] if !ok { return 0, tcpip.ErrNoRoute diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 6461d0108..0796d717e 100755 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -214,7 +214,7 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 0a6b8945c..062388f4d 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -200,7 +200,7 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { logPacket("send", protocol, pkt.Header.View(), gso) } @@ -233,20 +233,16 @@ func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumb // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { - e.dumpPacket(gso, protocol, pkt) + e.dumpPacket(gso, protocol, &pkt) return e.lower.WritePacket(r, gso, protocol, pkt) } // WritePackets implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - view := pkts[0].Data.ToView() - for _, pkt := range pkts { - e.dumpPacket(gso, protocol, stack.PacketBuffer{ - Header: pkt.Header, - Data: view[pkt.DataOffset:][:pkt.DataSize].ToVectorisedView(), - }) +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.dumpPacket(gso, protocol, pkt) } return e.lower.WritePackets(r, gso, pkts, protocol) } diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 52fe397bf..2b3741276 100755 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -112,9 +112,9 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets implements stack.LinkEndpoint.WritePackets. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { if !e.writeGate.Enter() { - return len(pkts), nil + return pkts.Len(), nil } n, err := e.lower.WritePackets(r, gso, pkts, protocol) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 255098372..7acbfa0a8 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderPara } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } diff --git a/pkg/tcpip/network/fragmentation/reassembler_list.go b/pkg/tcpip/network/fragmentation/reassembler_list.go index a48422c97..b5d40bd25 100755 --- a/pkg/tcpip/network/fragmentation/reassembler_list.go +++ b/pkg/tcpip/network/fragmentation/reassembler_list.go @@ -52,12 +52,21 @@ func (l *reassemblerList) Back() *reassembler { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *reassemblerList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *reassemblerList) PushFront(e *reassembler) { linker := reassemblerElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { reassemblerElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *reassemblerList) PushBack(e *reassembler) { linker := reassemblerElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { reassemblerElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *reassemblerList) PushBackList(m *reassemblerList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index a7d9a8b25..104aafbed 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -280,28 +280,47 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("multiple packets in local loop") } if r.Loop&stack.PacketOut == 0 { - return len(pkts), nil + return pkts.Len(), nil + } + + for pkt := pkts.Front(); pkt != nil; { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) + pkt.NetworkHeader = buffer.View(ip) + pkt = pkt.Next() } // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - for i := range pkts { - if ok := ipt.Check(stack.Output, pkts[i]); !ok { - // iptables is telling us to drop the packet. + dropped := ipt.CheckPackets(stack.Output, pkts) + if len(dropped) == 0 { + // Fast path: If no packets are to be dropped then we can just invoke the + // faster WritePackets API directly. + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + + // Slow Path as we are dropping some packets in the batch degrade to + // emitting one packet at a time. + n := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if _, ok := dropped[pkt]; ok { continue } - ip := e.addIPHeader(r, &pkts[i].Header, pkts[i].DataSize, params) - pkts[i].NetworkHeader = buffer.View(ip) + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil { + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + n++ } - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, err + return n, nil } // WriteHeaderIncludedPacket writes a packet already containing a network diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 6d2d2c034..f91180aa3 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -79,7 +79,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Only the first view in vv is accounted for by h. To account for the // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. - payload := pkt.Data + payload := pkt.Data.Clone(nil) payload.RemoveFirst() if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index b462b8604..a815b4d9b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -143,19 +143,17 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("not implemented") } if r.Loop&stack.PacketOut == 0 { - return len(pkts), nil + return pkts.Len(), nil } - for i := range pkts { - hdr := &pkts[i].Header - size := pkts[i].DataSize - ip := e.addIPHeader(r, hdr, size, params) - pkts[i].NetworkHeader = buffer.View(ip) + for pb := pkts.Front(); pb != nil; pb = pb.Next() { + ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params) + pb.NetworkHeader = buffer.View(ip) } n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 37907ae24..6c0a4b24d 100755 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -209,6 +209,23 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { return true } +// CheckPackets runs pkts through the rules for hook and returns a map of packets that +// should not go forward. +// +// NOTE: unlike the Check API the returned map contains packets that should be +// dropped. +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if ok := it.Check(hook, *pkt); !ok { + if drop == nil { + drop = make(map[*PacketBuffer]struct{}) + } + drop[pkt] = struct{}{} + } + } + return drop +} + // Precondition: pkt.NetworkHeader is set. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us diff --git a/pkg/tcpip/stack/linkaddrentry_list.go b/pkg/tcpip/stack/linkaddrentry_list.go index 6697281cd..43022f9b3 100755 --- a/pkg/tcpip/stack/linkaddrentry_list.go +++ b/pkg/tcpip/stack/linkaddrentry_list.go @@ -52,12 +52,21 @@ func (l *linkAddrEntryList) Back() *linkAddrEntry { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *linkAddrEntryList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *linkAddrEntryList) PushFront(e *linkAddrEntry) { linker := linkAddrEntryElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { linkAddrEntryElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *linkAddrEntryList) PushBack(e *linkAddrEntry) { linker := linkAddrEntryElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { linkAddrEntryElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *linkAddrEntryList) PushBackList(m *linkAddrEntryList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9367de180..dc125f25e 100755 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -23,9 +23,11 @@ import ( // As a PacketBuffer traverses up the stack, it may be necessary to pass it to // multiple endpoints. Clone() should be called in such cases so that // modifications to the Data field do not affect other copies. -// -// +stateify savable type PacketBuffer struct { + // PacketBufferEntry is used to build an intrusive list of + // PacketBuffers. + PacketBufferEntry + // Data holds the payload of the packet. For inbound packets, it also // holds the headers, which are consumed as the packet moves up the // stack. Headers are guaranteed not to be split across views. @@ -34,14 +36,6 @@ type PacketBuffer struct { // or otherwise modified. Data buffer.VectorisedView - // DataOffset is used for GSO output. It is the offset into the Data - // field where the payload of this packet starts. - DataOffset int - - // DataSize is used for GSO output. It is the size of this packet's - // payload. - DataSize int - // Header holds the headers of outbound packets. As a packet is passed // down the stack, each layer adds to Header. Header buffer.Prependable diff --git a/pkg/tcpip/stack/packet_buffer_list.go b/pkg/tcpip/stack/packet_buffer_list.go new file mode 100755 index 000000000..460c74c5a --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer_list.go @@ -0,0 +1,193 @@ +package stack + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type PacketBufferElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (PacketBufferElementMapper) linkerFor(elem *PacketBuffer) *PacketBuffer { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type PacketBufferList struct { + head *PacketBuffer + tail *PacketBuffer +} + +// Reset resets list l to the empty state. +func (l *PacketBufferList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *PacketBufferList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *PacketBufferList) Front() *PacketBuffer { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *PacketBufferList) Back() *PacketBuffer { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *PacketBufferList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +func (l *PacketBufferList) PushFront(e *PacketBuffer) { + linker := PacketBufferElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + PacketBufferElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *PacketBufferList) PushBack(e *PacketBuffer) { + linker := PacketBufferElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + PacketBufferElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *PacketBufferList) PushBackList(m *PacketBufferList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + PacketBufferElementMapper{}.linkerFor(l.tail).SetNext(m.head) + PacketBufferElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *PacketBufferList) InsertAfter(b, e *PacketBuffer) { + bLinker := PacketBufferElementMapper{}.linkerFor(b) + eLinker := PacketBufferElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + PacketBufferElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *PacketBufferList) InsertBefore(a, e *PacketBuffer) { + aLinker := PacketBufferElementMapper{}.linkerFor(a) + eLinker := PacketBufferElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + PacketBufferElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *PacketBufferList) Remove(e *PacketBuffer) { + linker := PacketBufferElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + PacketBufferElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + PacketBufferElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type PacketBufferEntry struct { + next *PacketBuffer + prev *PacketBuffer +} + +// Next returns the entry that follows e in the list. +func (e *PacketBufferEntry) Next() *PacketBuffer { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *PacketBufferEntry) Prev() *PacketBuffer { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *PacketBufferEntry) SetNext(elem *PacketBuffer) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *PacketBufferEntry) SetPrev(elem *PacketBuffer) { + e.prev = elem +} diff --git a/pkg/tcpip/stack/packet_buffer_state.go b/pkg/tcpip/stack/packet_buffer_state.go deleted file mode 100755 index 0c6b7924c..000000000 --- a/pkg/tcpip/stack/packet_buffer_state.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import "gvisor.dev/gvisor/pkg/tcpip/buffer" - -// beforeSave is invoked by stateify. -func (pk *PacketBuffer) beforeSave() { - // Non-Data fields may be slices of the Data field. This causes - // problems for SR, so during save we make each header independent. - pk.Header = pk.Header.DeepCopy() - pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...) - pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...) - pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...) -} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index ac043b722..23ca9ee03 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -246,7 +246,7 @@ type NetworkEndpoint interface { // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. - WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. @@ -393,7 +393,7 @@ type LinkEndpoint interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, it may // require to change syscall filters. - WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) // WriteRawPacket writes a packet directly to the link. The packet // should already have an ethernet header. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 9fbe8a411..a0e5e0300 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -168,23 +168,26 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuff return err } -// WritePackets writes the set of packets through the given route. -func (r *Route) WritePackets(gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { +// WritePackets writes a list of n packets through the given route and returns +// the number of packets written. +func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { if !r.ref.isValidForOutgoing() { return 0, tcpip.ErrInvalidEndpointState } n, err := r.ref.ep.WritePackets(r, gso, pkts, params) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) } r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) - payloadSize := 0 - for i := 0; i < n; i++ { - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkts[i].Header.UsedLength())) - payloadSize += pkts[i].DataSize + + writtenBytes := 0 + for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { + writtenBytes += pb.Header.UsedLength() + writtenBytes += pb.Data.Size() } - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize)) + + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) return n, err } diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go index 7f038a856..6e98e8e95 100755 --- a/pkg/tcpip/stack/stack_state_autogen.go +++ b/pkg/tcpip/stack/stack_state_autogen.go @@ -32,30 +32,30 @@ func (x *linkAddrEntryEntry) load(m state.Map) { m.Load("prev", &x.prev) } -func (x *PacketBuffer) save(m state.Map) { +func (x *PacketBufferList) beforeSave() {} +func (x *PacketBufferList) save(m state.Map) { x.beforeSave() - m.Save("Data", &x.Data) - m.Save("DataOffset", &x.DataOffset) - m.Save("DataSize", &x.DataSize) - m.Save("Header", &x.Header) - m.Save("LinkHeader", &x.LinkHeader) - m.Save("NetworkHeader", &x.NetworkHeader) - m.Save("TransportHeader", &x.TransportHeader) - m.Save("Hash", &x.Hash) - m.Save("Owner", &x.Owner) -} - -func (x *PacketBuffer) afterLoad() {} -func (x *PacketBuffer) load(m state.Map) { - m.Load("Data", &x.Data) - m.Load("DataOffset", &x.DataOffset) - m.Load("DataSize", &x.DataSize) - m.Load("Header", &x.Header) - m.Load("LinkHeader", &x.LinkHeader) - m.Load("NetworkHeader", &x.NetworkHeader) - m.Load("TransportHeader", &x.TransportHeader) - m.Load("Hash", &x.Hash) - m.Load("Owner", &x.Owner) + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *PacketBufferList) afterLoad() {} +func (x *PacketBufferList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *PacketBufferEntry) beforeSave() {} +func (x *PacketBufferEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *PacketBufferEntry) afterLoad() {} +func (x *PacketBufferEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) } func (x *TransportEndpointID) beforeSave() {} @@ -147,7 +147,8 @@ func (x *multiPortEndpoint) load(m state.Map) { func init() { state.Register("pkg/tcpip/stack.linkAddrEntryList", (*linkAddrEntryList)(nil), state.Fns{Save: (*linkAddrEntryList).save, Load: (*linkAddrEntryList).load}) state.Register("pkg/tcpip/stack.linkAddrEntryEntry", (*linkAddrEntryEntry)(nil), state.Fns{Save: (*linkAddrEntryEntry).save, Load: (*linkAddrEntryEntry).load}) - state.Register("pkg/tcpip/stack.PacketBuffer", (*PacketBuffer)(nil), state.Fns{Save: (*PacketBuffer).save, Load: (*PacketBuffer).load}) + state.Register("pkg/tcpip/stack.PacketBufferList", (*PacketBufferList)(nil), state.Fns{Save: (*PacketBufferList).save, Load: (*PacketBufferList).load}) + state.Register("pkg/tcpip/stack.PacketBufferEntry", (*PacketBufferEntry)(nil), state.Fns{Save: (*PacketBufferEntry).save, Load: (*PacketBufferEntry).load}) state.Register("pkg/tcpip/stack.TransportEndpointID", (*TransportEndpointID)(nil), state.Fns{Save: (*TransportEndpointID).save, Load: (*TransportEndpointID).load}) state.Register("pkg/tcpip/stack.GSOType", (*GSOType)(nil), state.Fns{Save: (*GSOType).save, Load: (*GSOType).load}) state.Register("pkg/tcpip/stack.GSO", (*GSO)(nil), state.Fns{Save: (*GSO).save, Load: (*GSO).load}) diff --git a/pkg/tcpip/transport/icmp/icmp_packet_list.go b/pkg/tcpip/transport/icmp/icmp_packet_list.go index ddee31adb..42d63f976 100755 --- a/pkg/tcpip/transport/icmp/icmp_packet_list.go +++ b/pkg/tcpip/transport/icmp/icmp_packet_list.go @@ -52,12 +52,21 @@ func (l *icmpPacketList) Back() *icmpPacket { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *icmpPacketList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *icmpPacketList) PushFront(e *icmpPacket) { linker := icmpPacketElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { icmpPacketElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *icmpPacketList) PushBack(e *icmpPacket) { linker := icmpPacketElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *icmpPacketList) PushBackList(m *icmpPacketList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/tcpip/transport/packet/packet_list.go b/pkg/tcpip/transport/packet/packet_list.go index ad27c7c06..5231b066f 100755 --- a/pkg/tcpip/transport/packet/packet_list.go +++ b/pkg/tcpip/transport/packet/packet_list.go @@ -52,12 +52,21 @@ func (l *packetList) Back() *packet { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *packetList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *packetList) PushFront(e *packet) { linker := packetElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { packetElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *packetList) PushBack(e *packet) { linker := packetElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { packetElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *packetList) PushBackList(m *packetList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/tcpip/transport/raw/raw_packet_list.go b/pkg/tcpip/transport/raw/raw_packet_list.go index e8c1bc997..15a8c845b 100755 --- a/pkg/tcpip/transport/raw/raw_packet_list.go +++ b/pkg/tcpip/transport/raw/raw_packet_list.go @@ -52,12 +52,21 @@ func (l *rawPacketList) Back() *rawPacket { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *rawPacketList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *rawPacketList) PushFront(e *rawPacket) { linker := rawPacketElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { rawPacketElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *rawPacketList) PushBack(e *rawPacket) { linker := rawPacketElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { rawPacketElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *rawPacketList) PushBackList(m *rawPacketList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 3239a5911..2ca3fb809 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -756,8 +756,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { optLen := len(tf.opts) hdr := &pkt.Header - packetSize := pkt.DataSize - off := pkt.DataOffset + packetSize := pkt.Data.Size() // Initialize the header. tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) pkt.TransportHeader = buffer.View(tcp) @@ -782,12 +781,18 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { - xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, off, packetSize) + xsum = header.ChecksumVV(pkt.Data, xsum) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } } func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { + // We need to shallow clone the VectorisedView here as ReadToView will + // split the VectorisedView and Trim underlying views as it splits. Not + // doing the clone here will cause the underlying views of data itself + // to be altered. + data = data.Clone(nil) + optLen := len(tf.opts) if tf.rcvWnd > 0xffff { tf.rcvWnd = 0xffff @@ -796,31 +801,25 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso mss := int(gso.MSS) n := (data.Size() + mss - 1) / mss - // Allocate one big slice for all the headers. - hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen - buf := make([]byte, n*hdrSize) - pkts := make([]stack.PacketBuffer, n) - for i := range pkts { - pkts[i].Header = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize]) - } - size := data.Size() - off := 0 + hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen + var pkts stack.PacketBufferList for i := 0; i < n; i++ { packetSize := mss if packetSize > size { packetSize = size } size -= packetSize - pkts[i].DataOffset = off - pkts[i].DataSize = packetSize - pkts[i].Data = data - pkts[i].Hash = tf.txHash - pkts[i].Owner = owner - buildTCPHdr(r, tf, &pkts[i], gso) - off += packetSize + var pkt stack.PacketBuffer + pkt.Header = buffer.NewPrependable(hdrSize) + pkt.Hash = tf.txHash + pkt.Owner = owner + data.ReadToVV(&pkt.Data, packetSize) + buildTCPHdr(r, tf, &pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) + pkts.PushBack(&pkt) } + if tf.ttl == 0 { tf.ttl = r.DefaultTTL() } @@ -845,12 +844,10 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac } pkt := stack.PacketBuffer{ - Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), - DataOffset: 0, - DataSize: data.Size(), - Data: data, - Hash: tf.txHash, - Owner: owner, + Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), + Data: data, + Hash: tf.txHash, + Owner: owner, } buildTCPHdr(r, tf, &pkt, gso) diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index e6fe7985d..40461fd31 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -77,9 +77,11 @@ func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.V id: id, route: r.Clone(), } - s.views[0] = v - s.data = buffer.NewVectorisedView(len(v), s.views[:1]) s.rcvdTime = time.Now() + if len(v) != 0 { + s.views[0] = v + s.data = buffer.NewVectorisedView(len(v), s.views[:1]) + } return s } diff --git a/pkg/tcpip/transport/tcp/tcp_endpoint_list.go b/pkg/tcpip/transport/tcp/tcp_endpoint_list.go index 62c042aff..fb7046d8f 100755 --- a/pkg/tcpip/transport/tcp/tcp_endpoint_list.go +++ b/pkg/tcpip/transport/tcp/tcp_endpoint_list.go @@ -52,12 +52,21 @@ func (l *endpointList) Back() *endpoint { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *endpointList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *endpointList) PushFront(e *endpoint) { linker := endpointElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { endpointElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *endpointList) PushBack(e *endpoint) { linker := endpointElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { endpointElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *endpointList) PushBackList(m *endpointList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/tcpip/transport/tcp/tcp_segment_list.go b/pkg/tcpip/transport/tcp/tcp_segment_list.go index 27f17f037..21638041c 100755 --- a/pkg/tcpip/transport/tcp/tcp_segment_list.go +++ b/pkg/tcpip/transport/tcp/tcp_segment_list.go @@ -52,12 +52,21 @@ func (l *segmentList) Back() *segment { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *segmentList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *segmentList) PushFront(e *segment) { linker := segmentElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { segmentElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *segmentList) PushBack(e *segment) { linker := segmentElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { segmentElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *segmentList) PushBackList(m *segmentList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/tcpip/transport/udp/udp_packet_list.go b/pkg/tcpip/transport/udp/udp_packet_list.go index 2ae846eaa..a6513e1e4 100755 --- a/pkg/tcpip/transport/udp/udp_packet_list.go +++ b/pkg/tcpip/transport/udp/udp_packet_list.go @@ -52,12 +52,21 @@ func (l *udpPacketList) Back() *udpPacket { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *udpPacketList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *udpPacketList) PushFront(e *udpPacket) { linker := udpPacketElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *udpPacketList) PushBack(e *udpPacket) { linker := udpPacketElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *udpPacketList) PushBackList(m *udpPacketList) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/waiter/waiter_list.go b/pkg/waiter/waiter_list.go index 07950faa4..35431f5a4 100755 --- a/pkg/waiter/waiter_list.go +++ b/pkg/waiter/waiter_list.go @@ -52,12 +52,21 @@ func (l *waiterList) Back() *Entry { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *waiterList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *waiterList) PushFront(e *Entry) { linker := waiterElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { waiterElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -72,7 +81,6 @@ func (l *waiterList) PushBack(e *Entry) { linker := waiterElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { waiterElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -93,7 +101,6 @@ func (l *waiterList) PushBackList(m *waiterList) { l.tail = m.tail } - m.head = nil m.tail = nil } |