From 7e4073af12bed2c76bc5757ef3e5fbfba75308a0 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Tue, 24 Mar 2020 09:05:06 -0700 Subject: Move tcpip.PacketBuffer and IPTables to stack package. This is a precursor to be being able to build an intrusive list of PacketBuffers for use in queuing disciplines being implemented. Updates #2214 PiperOrigin-RevId: 302677662 --- pkg/tcpip/stack/iptables.go | 311 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 311 insertions(+) create mode 100644 pkg/tcpip/stack/iptables.go (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go new file mode 100644 index 000000000..37907ae24 --- /dev/null +++ b/pkg/tcpip/stack/iptables.go @@ -0,0 +1,311 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// Table names. +const ( + TablenameNat = "nat" + TablenameMangle = "mangle" + TablenameFilter = "filter" +) + +// Chain names as defined by net/ipv4/netfilter/ip_tables.c. +const ( + ChainNamePrerouting = "PREROUTING" + ChainNameInput = "INPUT" + ChainNameForward = "FORWARD" + ChainNameOutput = "OUTPUT" + ChainNamePostrouting = "POSTROUTING" +) + +// HookUnset indicates that there is no hook set for an entrypoint or +// underflow. +const HookUnset = -1 + +// DefaultTables returns a default set of tables. Each chain is set to accept +// all packets. +func DefaultTables() IPTables { + // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for + // iotas. + return IPTables{ + Tables: map[string]Table{ + TablenameNat: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: map[Hook]int{ + Prerouting: 0, + Input: 1, + Output: 2, + Postrouting: 3, + }, + Underflows: map[Hook]int{ + Prerouting: 0, + Input: 1, + Output: 2, + Postrouting: 3, + }, + UserChains: map[string]int{}, + }, + TablenameMangle: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: map[Hook]int{ + Prerouting: 0, + Output: 1, + }, + Underflows: map[Hook]int{ + Prerouting: 0, + Output: 1, + }, + UserChains: map[string]int{}, + }, + TablenameFilter: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: map[Hook]int{ + Input: 0, + Forward: 1, + Output: 2, + }, + Underflows: map[Hook]int{ + Input: 0, + Forward: 1, + Output: 2, + }, + UserChains: map[string]int{}, + }, + }, + Priorities: map[Hook][]string{ + Input: []string{TablenameNat, TablenameFilter}, + Prerouting: []string{TablenameMangle, TablenameNat}, + Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, + }, + } +} + +// EmptyFilterTable returns a Table with no rules and the filter table chains +// mapped to HookUnset. +func EmptyFilterTable() Table { + return Table{ + Rules: []Rule{}, + BuiltinChains: map[Hook]int{ + Input: HookUnset, + Forward: HookUnset, + Output: HookUnset, + }, + Underflows: map[Hook]int{ + Input: HookUnset, + Forward: HookUnset, + Output: HookUnset, + }, + UserChains: map[string]int{}, + } +} + +// EmptyNatTable returns a Table with no rules and the filter table chains +// mapped to HookUnset. +func EmptyNatTable() Table { + return Table{ + Rules: []Rule{}, + BuiltinChains: map[Hook]int{ + Prerouting: HookUnset, + Input: HookUnset, + Output: HookUnset, + Postrouting: HookUnset, + }, + Underflows: map[Hook]int{ + Prerouting: HookUnset, + Input: HookUnset, + Output: HookUnset, + Postrouting: HookUnset, + }, + UserChains: map[string]int{}, + } +} + +// A chainVerdict is what a table decides should be done with a packet. +type chainVerdict int + +const ( + // chainAccept indicates the packet should continue through netstack. + chainAccept chainVerdict = iota + + // chainAccept indicates the packet should be dropped. + chainDrop + + // chainReturn indicates the packet should return to the calling chain + // or the underflow rule of a builtin chain. + chainReturn +) + +// Check runs pkt through the rules for hook. It returns true when the packet +// should continue traversing the network stack and false when it should be +// dropped. +// +// Precondition: pkt.NetworkHeader is set. +func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { + // Go through each table containing the hook. + for _, tablename := range it.Priorities[hook] { + table := it.Tables[tablename] + ruleIdx := table.BuiltinChains[hook] + switch verdict := it.checkChain(hook, pkt, table, ruleIdx); verdict { + // If the table returns Accept, move on to the next table. + case chainAccept: + continue + // The Drop verdict is final. + case chainDrop: + return false + case chainReturn: + // Any Return from a built-in chain means we have to + // call the underflow. + underflow := table.Rules[table.Underflows[hook]] + switch v, _ := underflow.Target.Action(pkt); v { + case RuleAccept: + continue + case RuleDrop: + return false + case RuleJump, RuleReturn: + panic("Underflows should only return RuleAccept or RuleDrop.") + default: + panic(fmt.Sprintf("Unknown verdict: %d", v)) + } + + default: + panic(fmt.Sprintf("Unknown verdict %v.", verdict)) + } + } + + // Every table returned Accept. + return true +} + +// Precondition: pkt.NetworkHeader is set. +func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { + // Start from ruleIdx and walk the list of rules until a rule gives us + // a verdict. + for ruleIdx < len(table.Rules) { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx); verdict { + case RuleAccept: + return chainAccept + + case RuleDrop: + return chainDrop + + case RuleReturn: + return chainReturn + + case RuleJump: + // "Jumping" to the next rule just means we're + // continuing on down the list. + if jumpTo == ruleIdx+1 { + ruleIdx++ + continue + } + switch verdict := it.checkChain(hook, pkt, table, jumpTo); verdict { + case chainAccept: + return chainAccept + case chainDrop: + return chainDrop + case chainReturn: + ruleIdx++ + continue + default: + panic(fmt.Sprintf("Unknown verdict: %d", verdict)) + } + + default: + panic(fmt.Sprintf("Unknown verdict: %d", verdict)) + } + + } + + // We got through the entire table without a decision. Default to DROP + // for safety. + return chainDrop +} + +// Precondition: pk.NetworkHeader is set. +func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { + rule := table.Rules[ruleIdx] + + // If pkt.NetworkHeader hasn't been set yet, it will be contained in + // pkt.Data.First(). + if pkt.NetworkHeader == nil { + pkt.NetworkHeader = pkt.Data.First() + } + + // Check whether the packet matches the IP header filter. + if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader)) { + // Continue on to the next rule. + return RuleJump, ruleIdx + 1 + } + + // Go through each rule matcher. If they all match, run + // the rule target. + for _, matcher := range rule.Matchers { + matches, hotdrop := matcher.Match(hook, pkt, "") + if hotdrop { + return RuleDrop, 0 + } + if !matches { + // Continue on to the next rule. + return RuleJump, ruleIdx + 1 + } + } + + // All the matchers matched, so run the target. + return rule.Target.Action(pkt) +} + +func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool { + // TODO(gvisor.dev/issue/170): Support other fields of the filter. + // Check the transport protocol. + if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() { + return false + } + + // Check the destination IP. + dest := hdr.DestinationAddress() + matches := true + for i := range filter.Dst { + if dest[i]&filter.DstMask[i] != filter.Dst[i] { + matches = false + break + } + } + if matches == filter.DstInvert { + return false + } + + return true +} -- cgit v1.2.3 From fc99a7ebf0c24b6f7b3cfd6351436373ed54548b Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Fri, 3 Apr 2020 18:34:48 -0700 Subject: Refactor software GSO code. Software GSO implementation currently has a complicated code path with implicit assumptions that all packets to WritePackets carry same Data and it does this to avoid allocations on the path etc. But this makes it hard to reuse the WritePackets API. This change breaks all such assumptions by introducing a new Vectorised View API ReadToVV which can be used to cleanly split a VV into multiple independent VVs. Further this change also makes packet buffers linkable to form an intrusive list. This allows us to get rid of the array of packet buffers that are passed in the WritePackets API call and replace it with a list of packet buffers. While this code does introduce some more allocations in the benchmarks it doesn't cause any degradation. Updates #231 PiperOrigin-RevId: 304731742 --- pkg/ilist/list.go | 13 ++- pkg/sentry/kernel/kernel.go | 22 +++-- pkg/tcpip/buffer/view.go | 53 +++++++++- pkg/tcpip/buffer/view_test.go | 137 ++++++++++++++++++++++++++ pkg/tcpip/link/channel/channel.go | 18 ++-- pkg/tcpip/link/fdbased/endpoint.go | 162 ++++++++++++++----------------- pkg/tcpip/link/loopback/loopback.go | 2 +- pkg/tcpip/link/muxed/injectable.go | 2 +- pkg/tcpip/link/sharedmem/sharedmem.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 14 +-- pkg/tcpip/link/waitable/waitable.go | 4 +- pkg/tcpip/link/waitable/waitable_test.go | 6 +- pkg/tcpip/network/arp/arp.go | 2 +- pkg/tcpip/network/ip_test.go | 2 +- pkg/tcpip/network/ipv4/ipv4.go | 37 +++++-- pkg/tcpip/network/ipv6/icmp.go | 2 +- pkg/tcpip/network/ipv6/ipv6.go | 12 +-- pkg/tcpip/stack/BUILD | 14 ++- pkg/tcpip/stack/forwarder_test.go | 8 +- pkg/tcpip/stack/iptables.go | 17 ++++ pkg/tcpip/stack/ndp_test.go | 2 +- pkg/tcpip/stack/packet_buffer.go | 14 +-- pkg/tcpip/stack/packet_buffer_state.go | 27 ------ pkg/tcpip/stack/registration.go | 4 +- pkg/tcpip/stack/route.go | 19 ++-- pkg/tcpip/stack/stack_test.go | 2 +- pkg/tcpip/transport/tcp/connect.go | 47 +++++---- pkg/tcpip/transport/tcp/segment.go | 6 +- 28 files changed, 420 insertions(+), 230 deletions(-) delete mode 100644 pkg/tcpip/stack/packet_buffer_state.go (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go index 8f93e4d6d..0d07da3b1 100644 --- a/pkg/ilist/list.go +++ b/pkg/ilist/list.go @@ -86,12 +86,21 @@ func (l *List) Back() Element { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *List) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *List) PushFront(e Element) { linker := ElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { ElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -106,7 +115,6 @@ func (l *List) PushBack(e Element) { linker := ElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { ElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -127,7 +135,6 @@ func (l *List) PushBackList(m *List) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 0a448b57c..2e6f42b92 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -564,15 +564,25 @@ func (ts *TaskSet) unregisterEpollWaiters() { ts.mu.RLock() defer ts.mu.RUnlock() + + // Tasks that belong to the same process could potentially point to the + // same FDTable. So we retain a map of processed ones to avoid + // processing the same FDTable multiple times. + processed := make(map[*FDTable]struct{}) for t := range ts.Root.tids { // We can skip locking Task.mu here since the kernel is paused. - if t.fdTable != nil { - t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { - if e, ok := file.FileOperations.(*epoll.EventPoll); ok { - e.UnregisterEpollWaiters() - } - }) + if t.fdTable == nil { + continue + } + if _, ok := processed[t.fdTable]; ok { + continue } + t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { + if e, ok := file.FileOperations.(*epoll.EventPoll); ok { + e.UnregisterEpollWaiters() + } + }) + processed[t.fdTable] = struct{}{} } } diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 8d42cd066..8ec5d5d5c 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -17,6 +17,7 @@ package buffer import ( "bytes" + "io" ) // View is a slice of a buffer, with convenience methods. @@ -89,6 +90,47 @@ func (vv *VectorisedView) TrimFront(count int) { } } +// Read implements io.Reader. +func (vv *VectorisedView) Read(v View) (copied int, err error) { + count := len(v) + for count > 0 && len(vv.views) > 0 { + if count < len(vv.views[0]) { + vv.size -= count + copy(v[copied:], vv.views[0][:count]) + vv.views[0].TrimFront(count) + copied += count + return copied, nil + } + count -= len(vv.views[0]) + copy(v[copied:], vv.views[0]) + copied += len(vv.views[0]) + vv.RemoveFirst() + } + if copied == 0 { + return 0, io.EOF + } + return copied, nil +} + +// ReadToVV reads up to n bytes from vv to dstVV and removes them from vv. It +// returns the number of bytes copied. +func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int) { + for count > 0 && len(vv.views) > 0 { + if count < len(vv.views[0]) { + vv.size -= count + dstVV.AppendView(vv.views[0][:count]) + vv.views[0].TrimFront(count) + copied += count + return + } + count -= len(vv.views[0]) + dstVV.AppendView(vv.views[0]) + copied += len(vv.views[0]) + vv.RemoveFirst() + } + return copied +} + // CapLength irreversibly reduces the length of the vectorised view. func (vv *VectorisedView) CapLength(length int) { if length < 0 { @@ -116,12 +158,12 @@ func (vv *VectorisedView) CapLength(length int) { // Clone returns a clone of this VectorisedView. // If the buffer argument is large enough to contain all the Views of this VectorisedView, // the method will avoid allocations and use the buffer to store the Views of the clone. -func (vv VectorisedView) Clone(buffer []View) VectorisedView { +func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } // First returns the first view of the vectorised view. -func (vv VectorisedView) First() View { +func (vv *VectorisedView) First() View { if len(vv.views) == 0 { return nil } @@ -134,11 +176,12 @@ func (vv *VectorisedView) RemoveFirst() { return } vv.size -= len(vv.views[0]) + vv.views[0] = nil vv.views = vv.views[1:] } // Size returns the size in bytes of the entire content stored in the vectorised view. -func (vv VectorisedView) Size() int { +func (vv *VectorisedView) Size() int { return vv.size } @@ -146,7 +189,7 @@ func (vv VectorisedView) Size() int { // // If the vectorised view contains a single view, that view will be returned // directly. -func (vv VectorisedView) ToView() View { +func (vv *VectorisedView) ToView() View { if len(vv.views) == 1 { return vv.views[0] } @@ -158,7 +201,7 @@ func (vv VectorisedView) ToView() View { } // Views returns the slice containing the all views. -func (vv VectorisedView) Views() []View { +func (vv *VectorisedView) Views() []View { return vv.views } diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index ebc3a17b7..106e1994c 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -233,3 +233,140 @@ func TestToClone(t *testing.T) { }) } } + +func TestVVReadToVV(t *testing.T) { + testCases := []struct { + comment string + vv VectorisedView + bytesToRead int + wantBytes string + leftVV VectorisedView + }{ + { + comment: "large VV, short read", + vv: vv(30, "012345678901234567890123456789"), + bytesToRead: 10, + wantBytes: "0123456789", + leftVV: vv(20, "01234567890123456789"), + }, + { + comment: "largeVV, multiple views, short read", + vv: vv(13, "123", "345", "567", "8910"), + bytesToRead: 6, + wantBytes: "123345", + leftVV: vv(7, "567", "8910"), + }, + { + comment: "smallVV (multiple views), large read", + vv: vv(3, "1", "2", "3"), + bytesToRead: 10, + wantBytes: "123", + leftVV: vv(0, ""), + }, + { + comment: "smallVV (single view), large read", + vv: vv(1, "1"), + bytesToRead: 10, + wantBytes: "1", + leftVV: vv(0, ""), + }, + { + comment: "emptyVV, large read", + vv: vv(0, ""), + bytesToRead: 10, + wantBytes: "", + leftVV: vv(0, ""), + }, + } + + for _, tc := range testCases { + t.Run(tc.comment, func(t *testing.T) { + var readTo VectorisedView + inSize := tc.vv.Size() + copied := tc.vv.ReadToVV(&readTo, tc.bytesToRead) + if got, want := copied, len(tc.wantBytes); got != want { + t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc: %+v", got, want, tc) + } + if got, want := string(readTo.ToView()), tc.wantBytes; got != want { + t.Errorf("unexpected content in readTo got: %s, want: %s", got, want) + } + if got, want := tc.vv.Size(), inSize-copied; got != want { + t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) + } + if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want { + t.Errorf("unexpected data left in vv after read got: %+v, want: %+v", got, want) + } + }) + } +} + +func TestVVRead(t *testing.T) { + testCases := []struct { + comment string + vv VectorisedView + bytesToRead int + readBytes string + leftBytes string + wantError bool + }{ + { + comment: "large VV, short read", + vv: vv(30, "012345678901234567890123456789"), + bytesToRead: 10, + readBytes: "0123456789", + leftBytes: "01234567890123456789", + }, + { + comment: "largeVV, multiple buffers, short read", + vv: vv(13, "123", "345", "567", "8910"), + bytesToRead: 6, + readBytes: "123345", + leftBytes: "5678910", + }, + { + comment: "smallVV, large read", + vv: vv(3, "1", "2", "3"), + bytesToRead: 10, + readBytes: "123", + leftBytes: "", + }, + { + comment: "smallVV, large read", + vv: vv(1, "1"), + bytesToRead: 10, + readBytes: "1", + leftBytes: "", + }, + { + comment: "emptyVV, large read", + vv: vv(0, ""), + bytesToRead: 10, + readBytes: "", + wantError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.comment, func(t *testing.T) { + readTo := NewView(tc.bytesToRead) + inSize := tc.vv.Size() + copied, err := tc.vv.Read(readTo) + if !tc.wantError && err != nil { + t.Fatalf("unexpected error in tc.vv.Read(..) = %s", err) + } + readTo = readTo[:copied] + if got, want := copied, len(tc.readBytes); got != want { + t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) + } + if got, want := string(readTo), tc.readBytes; got != want { + t.Errorf("unexpected data in readTo got: %s, want: %s", got, want) + } + if got, want := tc.vv.Size(), inSize-copied; got != want { + t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) + } + if got, want := string(tc.vv.ToView()), tc.leftBytes; got != want { + t.Errorf("vv has incorrect data after Read got: %s, want: %s", got, want) + } + }) + } +} diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index a8d6653ce..b4a0ae53d 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -28,7 +28,7 @@ import ( // PacketInfo holds all the information about an outbound packet. type PacketInfo struct { - Pkt stack.PacketBuffer + Pkt *stack.PacketBuffer Proto tcpip.NetworkProtocolNumber GSO *stack.GSO Route stack.Route @@ -257,7 +257,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne route := r.Clone() route.Release() p := PacketInfo{ - Pkt: pkt, + Pkt: &pkt, Proto: protocol, GSO: gso, Route: route, @@ -269,21 +269,15 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne } // WritePackets stores outbound packets into the channel. -func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { // Clone r then release its resource so we only get the relevant fields from // stack.Route without holding a reference to a NIC's endpoint. route := r.Clone() route.Release() - payloadView := pkts[0].Data.ToView() n := 0 - for _, pkt := range pkts { - off := pkt.DataOffset - size := pkt.DataSize + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { p := PacketInfo{ - Pkt: stack.PacketBuffer{ - Header: pkt.Header, - Data: buffer.NewViewFromBytes(payloadView[off : off+size]).ToVectorisedView(), - }, + Pkt: pkt, Proto: protocol, GSO: gso, Route: route, @@ -301,7 +295,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.Pac // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := PacketInfo{ - Pkt: stack.PacketBuffer{Data: vv}, + Pkt: &stack.PacketBuffer{Data: vv}, Proto: 0, GSO: nil, } diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 3b3b6909b..7198742b7 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -441,118 +441,106 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - var ethHdrBuf []byte - // hdr + data - iovLen := 2 - if e.hdrSize > 0 { - // Add ethernet header if needed. - ethHdrBuf = make([]byte, header.EthernetMinimumSize) - eth := header.Ethernet(ethHdrBuf) - ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, - Type: protocol, - } - - // Preserve the src address if it's set in the route. - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress - } else { - ethHdr.SrcAddr = e.addr - } - eth.Encode(ethHdr) - iovLen++ - } +// +// NOTE: This API uses sendmmsg to batch packets. As a result the underlying FD +// picked to write the packet out has to be the same for all packets in the +// list. In other words all packets in the batch should belong to the same +// flow. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + n := pkts.Len() - n := len(pkts) - - views := pkts[0].Data.Views() - /* - * Each boundary in views can add one more iovec. - * - * payload | | | | - * ----------------------------- - * packets | | | | | | | - * ----------------------------- - * iovecs | | | | | | | | | - */ - iovec := make([]syscall.Iovec, n*iovLen+len(views)-1) mmsgHdrs := make([]rawfile.MMsgHdr, n) + i := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + var ethHdrBuf []byte + iovLen := 0 + if e.hdrSize > 0 { + // Add ethernet header if needed. + ethHdrBuf = make([]byte, header.EthernetMinimumSize) + eth := header.Ethernet(ethHdrBuf) + ethHdr := &header.EthernetFields{ + DstAddr: r.RemoteLinkAddress, + Type: protocol, + } - iovecIdx := 0 - viewIdx := 0 - viewOff := 0 - off := 0 - nextOff := 0 - for i := range pkts { - // TODO(b/134618279): Different packets may have different data - // in the future. We should handle this. - if !viewsEqual(pkts[i].Data.Views(), views) { - panic("All packets in pkts should have the same Data.") + // Preserve the src address if it's set in the route. + if r.LocalLinkAddress != "" { + ethHdr.SrcAddr = r.LocalLinkAddress + } else { + ethHdr.SrcAddr = e.addr + } + eth.Encode(ethHdr) + iovLen++ } - prevIovecIdx := iovecIdx - mmsgHdr := &mmsgHdrs[i] - mmsgHdr.Msg.Iov = &iovec[iovecIdx] - packetSize := pkts[i].DataSize - hdr := &pkts[i].Header - - off = pkts[i].DataOffset - if off != nextOff { - // We stop in a different point last time. - size := packetSize - viewIdx = 0 - viewOff = 0 - for size > 0 { - if size >= len(views[viewIdx]) { - viewIdx++ - viewOff = 0 - size -= len(views[viewIdx]) - } else { - viewOff = size - size = 0 + var vnetHdrBuf []byte + vnetHdr := virtioNetHdr{} + if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if gso != nil { + vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) + if gso.NeedsCsum { + vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM + vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen + vnetHdr.csumOffset = gso.CsumOffset + } + if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS { + switch gso.Type { + case stack.GSOTCPv4: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 + case stack.GSOTCPv6: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 + default: + panic(fmt.Sprintf("Unknown gso type: %v", gso.Type)) + } + vnetHdr.gsoSize = gso.MSS } } + vnetHdrBuf = vnetHdrToByteSlice(&vnetHdr) + iovLen++ } - nextOff = off + packetSize + iovecs := make([]syscall.Iovec, iovLen+1+len(pkt.Data.Views())) + mmsgHdr := &mmsgHdrs[i] + mmsgHdr.Msg.Iov = &iovecs[0] + iovecIdx := 0 + if vnetHdrBuf != nil { + v := &iovecs[iovecIdx] + v.Base = &vnetHdrBuf[0] + v.Len = uint64(len(vnetHdrBuf)) + iovecIdx++ + } if ethHdrBuf != nil { - v := &iovec[iovecIdx] + v := &iovecs[iovecIdx] v.Base = ðHdrBuf[0] v.Len = uint64(len(ethHdrBuf)) iovecIdx++ } - - v := &iovec[iovecIdx] + pktSize := uint64(0) + // Encode L3 Header + v := &iovecs[iovecIdx] + hdr := &pkt.Header hdrView := hdr.View() v.Base = &hdrView[0] v.Len = uint64(len(hdrView)) + pktSize += v.Len iovecIdx++ - for packetSize > 0 { - vec := &iovec[iovecIdx] + // Now encode the Transport Payload. + pktViews := pkt.Data.Views() + for i := range pktViews { + vec := &iovecs[iovecIdx] iovecIdx++ - - v := views[viewIdx] - vec.Base = &v[viewOff] - s := len(v) - viewOff - if s <= packetSize { - viewIdx++ - viewOff = 0 - } else { - s = packetSize - viewOff += s - } - vec.Len = uint64(s) - packetSize -= s + vec.Base = &pktViews[i][0] + vec.Len = uint64(len(pktViews[i])) + pktSize += vec.Len } - - mmsgHdr.Msg.Iovlen = uint64(iovecIdx - prevIovecIdx) + mmsgHdr.Msg.Iovlen = uint64(iovecIdx) + i++ } packets := 0 for packets < n { - fd := e.fds[pkts[packets].Hash%uint32(len(e.fds))] + fd := e.fds[pkts.Front().Hash%uint32(len(e.fds))] sent, err := rawfile.NonBlockingSendMMsg(fd, mmsgHdrs) if err != nil { return packets, err diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 4039753b7..1e2255bfa 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketBuffer, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index f5973066d..a5478ce17 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -87,7 +87,7 @@ func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, // WritePackets writes outbound packets to the appropriate // LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if // r.RemoteAddress has a route registered in this endpoint. -func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { endpoint, ok := m.routes[r.RemoteAddress] if !ok { return 0, tcpip.ErrNoRoute diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 6461d0108..0796d717e 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -214,7 +214,7 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 0a6b8945c..062388f4d 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -200,7 +200,7 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { logPacket("send", protocol, pkt.Header.View(), gso) } @@ -233,20 +233,16 @@ func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumb // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { - e.dumpPacket(gso, protocol, pkt) + e.dumpPacket(gso, protocol, &pkt) return e.lower.WritePacket(r, gso, protocol, pkt) } // WritePackets implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - view := pkts[0].Data.ToView() - for _, pkt := range pkts { - e.dumpPacket(gso, protocol, stack.PacketBuffer{ - Header: pkt.Header, - Data: view[pkt.DataOffset:][:pkt.DataSize].ToVectorisedView(), - }) +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.dumpPacket(gso, protocol, pkt) } return e.lower.WritePackets(r, gso, pkts, protocol) } diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 52fe397bf..2b3741276 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -112,9 +112,9 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets implements stack.LinkEndpoint.WritePackets. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { if !e.writeGate.Enter() { - return len(pkts), nil + return pkts.Len(), nil } n, err := e.lower.WritePackets(r, gso, pkts, protocol) diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 88224e494..54eb5322b 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -71,9 +71,9 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcp } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - e.writeCount += len(pkts) - return len(pkts), nil +func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + e.writeCount += pkts.Len() + return pkts.Len(), nil } func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 255098372..7acbfa0a8 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderPara } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4950d69fc..4c20301c6 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -172,7 +172,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index a7d9a8b25..104aafbed 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -280,28 +280,47 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("multiple packets in local loop") } if r.Loop&stack.PacketOut == 0 { - return len(pkts), nil + return pkts.Len(), nil + } + + for pkt := pkts.Front(); pkt != nil; { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) + pkt.NetworkHeader = buffer.View(ip) + pkt = pkt.Next() } // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - for i := range pkts { - if ok := ipt.Check(stack.Output, pkts[i]); !ok { - // iptables is telling us to drop the packet. + dropped := ipt.CheckPackets(stack.Output, pkts) + if len(dropped) == 0 { + // Fast path: If no packets are to be dropped then we can just invoke the + // faster WritePackets API directly. + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + + // Slow Path as we are dropping some packets in the batch degrade to + // emitting one packet at a time. + n := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if _, ok := dropped[pkt]; ok { continue } - ip := e.addIPHeader(r, &pkts[i].Header, pkts[i].DataSize, params) - pkts[i].NetworkHeader = buffer.View(ip) + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil { + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + n++ } - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, err + return n, nil } // WriteHeaderIncludedPacket writes a packet already containing a network diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 6d2d2c034..f91180aa3 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -79,7 +79,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Only the first view in vv is accounted for by h. To account for the // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. - payload := pkt.Data + payload := pkt.Data.Clone(nil) payload.RemoveFirst() if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index b462b8604..a815b4d9b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -143,19 +143,17 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("not implemented") } if r.Loop&stack.PacketOut == 0 { - return len(pkts), nil + return pkts.Len(), nil } - for i := range pkts { - hdr := &pkts[i].Header - size := pkts[i].DataSize - ip := e.addIPHeader(r, hdr, size, params) - pkts[i].NetworkHeader = buffer.View(ip) + for pb := pkts.Front(); pb != nil; pb = pb.Next() { + ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params) + pb.NetworkHeader = buffer.View(ip) } n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 8d80e9cee..5e963a4af 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -15,6 +15,18 @@ go_template_instance( }, ) +go_template_instance( + name = "packet_buffer_list", + out = "packet_buffer_list.go", + package = "stack", + prefix = "PacketBuffer", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*PacketBuffer", + "Linker": "*PacketBuffer", + }, +) + go_library( name = "stack", srcs = [ @@ -29,7 +41,7 @@ go_library( "ndp.go", "nic.go", "packet_buffer.go", - "packet_buffer_state.go", + "packet_buffer_list.go", "rand.go", "registration.go", "route.go", diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index c45c43d21..e9c652042 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -101,7 +101,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH } // WritePackets implements LinkEndpoint.WritePackets. -func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { +func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } @@ -260,10 +260,10 @@ func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.Netw } // WritePackets stores outbound packets into the channel. -func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { n := 0 - for _, pkt := range pkts { - e.WritePacket(r, gso, protocol, pkt) + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.WritePacket(r, gso, protocol, *pkt) n++ } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 37907ae24..6c0a4b24d 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -209,6 +209,23 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { return true } +// CheckPackets runs pkts through the rules for hook and returns a map of packets that +// should not go forward. +// +// NOTE: unlike the Check API the returned map contains packets that should be +// dropped. +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if ok := it.Check(hook, *pkt); !ok { + if drop == nil { + drop = make(map[*PacketBuffer]struct{}) + } + drop[pkt] = struct{}{} + } + } + return drop +} + // Precondition: pkt.NetworkHeader is set. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 598468bdd..27dc8baf9 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -468,7 +468,7 @@ func TestDADResolve(t *testing.T) { // As per RFC 4861 section 4.3, a possible option is the Source Link // Layer option, but this option MUST NOT be included when the source // address of the packet is the unspecified address. - checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(), + checker.IPv6(t, p.Pkt.Header.View(), checker.SrcAddr(header.IPv6Any), checker.DstAddr(snmc), checker.TTL(header.NDPHopLimit), diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9367de180..dc125f25e 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -23,9 +23,11 @@ import ( // As a PacketBuffer traverses up the stack, it may be necessary to pass it to // multiple endpoints. Clone() should be called in such cases so that // modifications to the Data field do not affect other copies. -// -// +stateify savable type PacketBuffer struct { + // PacketBufferEntry is used to build an intrusive list of + // PacketBuffers. + PacketBufferEntry + // Data holds the payload of the packet. For inbound packets, it also // holds the headers, which are consumed as the packet moves up the // stack. Headers are guaranteed not to be split across views. @@ -34,14 +36,6 @@ type PacketBuffer struct { // or otherwise modified. Data buffer.VectorisedView - // DataOffset is used for GSO output. It is the offset into the Data - // field where the payload of this packet starts. - DataOffset int - - // DataSize is used for GSO output. It is the size of this packet's - // payload. - DataSize int - // Header holds the headers of outbound packets. As a packet is passed // down the stack, each layer adds to Header. Header buffer.Prependable diff --git a/pkg/tcpip/stack/packet_buffer_state.go b/pkg/tcpip/stack/packet_buffer_state.go deleted file mode 100644 index 0c6b7924c..000000000 --- a/pkg/tcpip/stack/packet_buffer_state.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import "gvisor.dev/gvisor/pkg/tcpip/buffer" - -// beforeSave is invoked by stateify. -func (pk *PacketBuffer) beforeSave() { - // Non-Data fields may be slices of the Data field. This causes - // problems for SR, so during save we make each header independent. - pk.Header = pk.Header.DeepCopy() - pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...) - pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...) - pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...) -} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index ac043b722..23ca9ee03 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -246,7 +246,7 @@ type NetworkEndpoint interface { // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. - WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. @@ -393,7 +393,7 @@ type LinkEndpoint interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, it may // require to change syscall filters. - WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) // WriteRawPacket writes a packet directly to the link. The packet // should already have an ethernet header. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 9fbe8a411..a0e5e0300 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -168,23 +168,26 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuff return err } -// WritePackets writes the set of packets through the given route. -func (r *Route) WritePackets(gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { +// WritePackets writes a list of n packets through the given route and returns +// the number of packets written. +func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { if !r.ref.isValidForOutgoing() { return 0, tcpip.ErrInvalidEndpointState } n, err := r.ref.ep.WritePackets(r, gso, pkts, params) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) } r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) - payloadSize := 0 - for i := 0; i < n; i++ { - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkts[i].Header.UsedLength())) - payloadSize += pkts[i].DataSize + + writtenBytes := 0 + for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { + writtenBytes += pb.Header.UsedLength() + writtenBytes += pb.Data.Size() } - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize)) + + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) return n, err } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index b8543b71e..3f8a2a095 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -153,7 +153,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 3239a5911..2ca3fb809 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -756,8 +756,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { optLen := len(tf.opts) hdr := &pkt.Header - packetSize := pkt.DataSize - off := pkt.DataOffset + packetSize := pkt.Data.Size() // Initialize the header. tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) pkt.TransportHeader = buffer.View(tcp) @@ -782,12 +781,18 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { - xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, off, packetSize) + xsum = header.ChecksumVV(pkt.Data, xsum) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } } func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { + // We need to shallow clone the VectorisedView here as ReadToView will + // split the VectorisedView and Trim underlying views as it splits. Not + // doing the clone here will cause the underlying views of data itself + // to be altered. + data = data.Clone(nil) + optLen := len(tf.opts) if tf.rcvWnd > 0xffff { tf.rcvWnd = 0xffff @@ -796,31 +801,25 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso mss := int(gso.MSS) n := (data.Size() + mss - 1) / mss - // Allocate one big slice for all the headers. - hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen - buf := make([]byte, n*hdrSize) - pkts := make([]stack.PacketBuffer, n) - for i := range pkts { - pkts[i].Header = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize]) - } - size := data.Size() - off := 0 + hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen + var pkts stack.PacketBufferList for i := 0; i < n; i++ { packetSize := mss if packetSize > size { packetSize = size } size -= packetSize - pkts[i].DataOffset = off - pkts[i].DataSize = packetSize - pkts[i].Data = data - pkts[i].Hash = tf.txHash - pkts[i].Owner = owner - buildTCPHdr(r, tf, &pkts[i], gso) - off += packetSize + var pkt stack.PacketBuffer + pkt.Header = buffer.NewPrependable(hdrSize) + pkt.Hash = tf.txHash + pkt.Owner = owner + data.ReadToVV(&pkt.Data, packetSize) + buildTCPHdr(r, tf, &pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) + pkts.PushBack(&pkt) } + if tf.ttl == 0 { tf.ttl = r.DefaultTTL() } @@ -845,12 +844,10 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac } pkt := stack.PacketBuffer{ - Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), - DataOffset: 0, - DataSize: data.Size(), - Data: data, - Hash: tf.txHash, - Owner: owner, + Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), + Data: data, + Hash: tf.txHash, + Owner: owner, } buildTCPHdr(r, tf, &pkt, gso) diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index e6fe7985d..40461fd31 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -77,9 +77,11 @@ func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.V id: id, route: r.Clone(), } - s.views[0] = v - s.data = buffer.NewVectorisedView(len(v), s.views[:1]) s.rcvdTime = time.Now() + if len(v) != 0 { + s.views[0] = v + s.data = buffer.NewVectorisedView(len(v), s.views[:1]) + } return s } -- cgit v1.2.3 From a551add5d8a5bf631cd9859c761e579fdb33ec82 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Mon, 13 Apr 2020 17:37:21 -0700 Subject: Remove View.First() and View.RemoveFirst() These methods let users eaily break the VectorisedView abstraction, and allowed netstack to slip into pseudo-enforcement of the "all headers are in the first View" invariant. Removing them and replacing with PullUp(n) breaks this reliance and will make it easier to add iptables support and rework network buffer management. The new View.PullUp(n) method is low cost in the common case, when when all the headers fit in the first View. --- pkg/sentry/socket/netfilter/tcp_matcher.go | 5 +- pkg/sentry/socket/netfilter/udp_matcher.go | 5 +- pkg/tcpip/buffer/view.go | 55 ++++++++++---- pkg/tcpip/buffer/view_test.go | 113 +++++++++++++++++++++++++++++ pkg/tcpip/link/loopback/loopback.go | 10 +-- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 65 +++++++++++++---- pkg/tcpip/network/arp/arp.go | 5 +- pkg/tcpip/network/ipv4/icmp.go | 20 +++-- pkg/tcpip/network/ipv4/ipv4.go | 12 ++- pkg/tcpip/network/ipv6/icmp.go | 74 ++++++++++++------- pkg/tcpip/network/ipv6/icmp_test.go | 3 +- pkg/tcpip/network/ipv6/ipv6.go | 6 +- pkg/tcpip/stack/forwarder_test.go | 13 ++-- pkg/tcpip/stack/iptables.go | 22 +++++- pkg/tcpip/stack/iptables_targets.go | 23 ++++-- pkg/tcpip/stack/nic.go | 34 +++------ pkg/tcpip/stack/packet_buffer.go | 4 +- pkg/tcpip/stack/stack_test.go | 10 ++- pkg/tcpip/stack/transport_test.go | 5 +- pkg/tcpip/transport/icmp/endpoint.go | 8 +- pkg/tcpip/transport/tcp/segment.go | 29 +++++--- pkg/tcpip/transport/tcp/tcp_test.go | 4 +- pkg/tcpip/transport/udp/endpoint.go | 6 +- pkg/tcpip/transport/udp/protocol.go | 9 ++- 25 files changed, 395 insertions(+), 147 deletions(-) (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index ff1cfd8f6..55c0f04f3 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -121,12 +121,13 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa tcpHeader = header.TCP(pkt.TransportHeader) } else { // The TCP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.TCPMinimumSize { + hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + if !ok { // There's no valid TCP header here, so we hotdrop the // packet. return false, true } - tcpHeader = header.TCP(pkt.Data.First()) + tcpHeader = header.TCP(hdr) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 3359418c1..04d03d494 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -120,12 +120,13 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa udpHeader = header.UDP(pkt.TransportHeader) } else { // The UDP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.UDPMinimumSize { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { // There's no valid UDP header here, so we hotdrop the // packet. return false, true } - udpHeader = header.UDP(pkt.Data.First()) + udpHeader = header.UDP(hdr) } // Check whether the source and destination ports are within the diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 8ec5d5d5c..f01217c91 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -77,7 +77,8 @@ func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } -// TrimFront removes the first "count" bytes of the vectorised view. +// TrimFront removes the first "count" bytes of the vectorised view. It panics +// if count > vv.Size(). func (vv *VectorisedView) TrimFront(count int) { for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { @@ -86,7 +87,7 @@ func (vv *VectorisedView) TrimFront(count int) { return } count -= len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } } @@ -104,7 +105,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) { count -= len(vv.views[0]) copy(v[copied:], vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } if copied == 0 { return 0, io.EOF @@ -126,7 +127,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int count -= len(vv.views[0]) dstVV.AppendView(vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } return copied } @@ -162,22 +163,37 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } -// First returns the first view of the vectorised view. -func (vv *VectorisedView) First() View { +// PullUp returns the first "count" bytes of the vectorised view. If those +// bytes aren't already contiguous inside the vectorised view, PullUp will +// reallocate as needed to make them contiguous. PullUp fails and returns false +// when count > vv.Size(). +func (vv *VectorisedView) PullUp(count int) (View, bool) { if len(vv.views) == 0 { - return nil + return nil, count == 0 + } + if count <= len(vv.views[0]) { + return vv.views[0][:count], true + } + if count > vv.size { + return nil, false } - return vv.views[0] -} -// RemoveFirst removes the first view of the vectorised view. -func (vv *VectorisedView) RemoveFirst() { - if len(vv.views) == 0 { - return + newFirst := NewView(count) + i := 0 + for offset := 0; offset < count; i++ { + copy(newFirst[offset:], vv.views[i]) + if count-offset < len(vv.views[i]) { + vv.views[i].TrimFront(count - offset) + break + } + offset += len(vv.views[i]) + vv.views[i] = nil } - vv.size -= len(vv.views[0]) - vv.views[0] = nil - vv.views = vv.views[1:] + // We're guaranteed that i > 0, since count is too large for the first + // view. + vv.views[i-1] = newFirst + vv.views = vv.views[i-1:] + return newFirst, true } // Size returns the size in bytes of the entire content stored in the vectorised view. @@ -225,3 +241,10 @@ func (vv *VectorisedView) Readers() []bytes.Reader { } return readers } + +// removeFirst panics when len(vv.views) < 1. +func (vv *VectorisedView) removeFirst() { + vv.size -= len(vv.views[0]) + vv.views[0] = nil + vv.views = vv.views[1:] +} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index 106e1994c..c56795c7b 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -16,6 +16,7 @@ package buffer import ( + "bytes" "reflect" "testing" ) @@ -370,3 +371,115 @@ func TestVVRead(t *testing.T) { }) } } + +var pullUpTestCases = []struct { + comment string + in VectorisedView + count int + want []byte + result VectorisedView + ok bool +}{ + { + comment: "simple case", + in: vv(2, "12"), + count: 1, + want: []byte("1"), + result: vv(2, "12"), + ok: true, + }, + { + comment: "entire View", + in: vv(2, "1", "2"), + count: 1, + want: []byte("1"), + result: vv(2, "1", "2"), + ok: true, + }, + { + comment: "spanning across two Views", + in: vv(3, "1", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, + { + comment: "spanning across all Views", + in: vv(5, "1", "23", "45"), + count: 5, + want: []byte("12345"), + result: vv(5, "12345"), + ok: true, + }, + { + comment: "count = 0", + in: vv(1, "1"), + count: 0, + want: []byte{}, + result: vv(1, "1"), + ok: true, + }, + { + comment: "count = size", + in: vv(1, "1"), + count: 1, + want: []byte("1"), + result: vv(1, "1"), + ok: true, + }, + { + comment: "count too large", + in: vv(3, "1", "23"), + count: 4, + want: nil, + result: vv(3, "1", "23"), + ok: false, + }, + { + comment: "empty vv", + in: vv(0, ""), + count: 1, + want: nil, + result: vv(0, ""), + ok: false, + }, + { + comment: "empty vv, count = 0", + in: vv(0, ""), + count: 0, + want: nil, + result: vv(0, ""), + ok: true, + }, + { + comment: "empty views", + in: vv(3, "", "1", "", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, +} + +func TestPullUp(t *testing.T) { + for _, c := range pullUpTestCases { + got, ok := c.in.PullUp(c.count) + + // Is the return value right? + if ok != c.ok { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", + c.comment, c.count, c.in, ok, c.ok) + } + if bytes.Compare(got, View(c.want)) != 0 { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", + c.comment, c.count, c.in, got, c.want) + } + + // Is the underlying structure right? + if !reflect.DeepEqual(c.in, c.result) { + t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", + c.comment, c.count, c.in, c.result) + } + } +} diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 1e2255bfa..073c84ef9 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // Reject the packet if it's shorter than an ethernet header. - if vv.Size() < header.EthernetMinimumSize { + // There should be an ethernet header at the beginning of vv. + hdr, ok := vv.PullUp(header.EthernetMinimumSize) + if !ok { + // Reject the packet if it's shorter than an ethernet header. return tcpip.ErrBadAddress } - - // There should be an ethernet header at the beginning of vv. - linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) + linkHeader := header.Ethernet(hdr) vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 27ea3f531..33f640b85 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.First()) + rcvd := []byte(c.packets[0].vv.ToView()) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index be2537a82..0799c8f4d 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -171,11 +171,7 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - first := pkt.Header.View() - if len(first) == 0 { - first = pkt.Data.First() - } - logPacket(prefix, protocol, first, gso) + logPacket(prefix, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Header.UsedLength() + pkt.Data.Size() @@ -238,7 +234,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() } -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -247,28 +243,49 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie size := uint16(0) var fragmentOffset uint16 var moreFragments bool + + // Create a clone of pkt, including any headers if present. Avoid allocating + // backing memory for the clone. + views := [8]buffer.View{} + vv := buffer.NewVectorisedView(0, views[:0]) + vv.AppendView(pkt.Header.View()) + vv.Append(pkt.Data) + switch protocol { case header.IPv4ProtocolNumber: - ipv4 := header.IPv4(b) + hdr, ok := vv.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + ipv4 := header.IPv4(hdr) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - b = b[ipv4.HeaderLength():] + vv.TrimFront(int(ipv4.HeaderLength())) id = int(ipv4.ID()) case header.IPv6ProtocolNumber: - ipv6 := header.IPv6(b) + hdr, ok := vv.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + ipv6 := header.IPv6(hdr) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() transProto = ipv6.NextHeader() size = ipv6.PayloadLength() - b = b[header.IPv6MinimumSize:] + vv.TrimFront(header.IPv6MinimumSize) case header.ARPProtocolNumber: - arp := header.ARP(b) + hdr, ok := vv.PullUp(header.ARPSize) + if !ok { + return + } + vv.TrimFront(header.ARPSize) + arp := header.ARP(hdr) log.Infof( "%s arp %v (%v) -> %v (%v) valid:%v", prefix, @@ -284,7 +301,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie // We aren't guaranteed to have a transport header - it's possible for // writes via raw endpoints to contain only network headers. - if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize { log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) return } @@ -297,7 +314,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - icmp := header.ICMPv4(b) + hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) + if !ok { + break + } + icmp := header.ICMPv4(hdr) icmpType := "unknown" if fragmentOffset == 0 { switch icmp.Type() { @@ -330,7 +351,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.ICMPv6ProtocolNumber: transName = "icmp" - icmp := header.ICMPv6(b) + hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) + if !ok { + break + } + icmp := header.ICMPv6(hdr) icmpType := "unknown" switch icmp.Type() { case header.ICMPv6DstUnreachable: @@ -361,7 +386,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.UDPProtocolNumber: transName = "udp" - udp := header.UDP(b) + hdr, ok := vv.PullUp(header.UDPMinimumSize) + if !ok { + break + } + udp := header.UDP(hdr) if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -371,7 +400,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.TCPProtocolNumber: transName = "tcp" - tcp := header.TCP(b) + hdr, ok := vv.PullUp(header.TCPMinimumSize) + if !ok { + break + } + tcp := header.TCP(hdr) if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7acbfa0a8..cf73a939e 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -93,7 +93,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf } func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v := pkt.Data.First() + v, ok := pkt.Data.PullUp(header.ARPSize) + if !ok { + return + } h := header.ARP(v) if !h.IsValid() { return diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index c4bf1ba5c..4cbefe5ab 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -25,7 +25,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + hdr := header.IPv4(h) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -34,12 +38,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } - hlen := int(h.HeaderLength()) - if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { + hlen := int(hdr.HeaderLength()) + if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -48,15 +52,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) - p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + p := hdr.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv4MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok { received.Invalid.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 104aafbed..17202cc7a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -328,7 +328,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - ip := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return tcpip.ErrInvalidOptionValue + } + ip := header.IPv4(h) if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } @@ -378,7 +382,11 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index b68983d10..bdf3a0d25 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -28,7 +28,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv6(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + hdr := header.IPv6(h) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -36,17 +40,21 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } // Skip the IP header, then handle the fragmentation header if there // is one. pkt.Data.TrimFront(header.IPv6MinimumSize) - p := h.TransportProtocol() + p := hdr.TransportProtocol() if p == header.IPv6FragmentHeader { - f := header.IPv6Fragment(pkt.Data.First()) - if !f.IsValid() || f.FragmentOffset() != 0 { + f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) + if !ok { + return + } + fragHdr := header.IPv6Fragment(f) + if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. return @@ -55,19 +63,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip fragmentation header and find out the actual protocol // number. pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) - p = f.TransportProtocol() + p = fragHdr.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv6MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) + if !ok { received.Invalid.Increment() return } @@ -76,11 +84,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Validate ICMPv6 checksum before processing the packet. // - // Only the first view in vv is accounted for by h. To account for the - // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) - payload.RemoveFirst() + payload.TrimFront(len(h)) if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return @@ -101,34 +107,40 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch h.Type() { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() - if len(v) < header.ICMPv6PacketTooBigMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := h.MTU() + mtu := header.ICMPv6(hdr).MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() - if len(v) < header.ICMPv6DstUnreachableMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch h.Code() { + switch header.ICMPv6(hdr).Code() { case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { received.Invalid.Increment() return } - ns := header.NDPNeighborSolicit(h.NDPPayload()) + // The remainder of payload must be only the neighbor solicitation, so + // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, + // NDP messages cannot be fragmented. Also note that in the common case NDP + // datagrams are very small and ToView() will not incur allocations. + ns := header.NDPNeighborSolicit(payload.ToView()) it, err := ns.Options().Iter(true) if err != nil { // If we have a malformed NDP NS option, drop the packet. @@ -286,12 +298,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { received.Invalid.Increment() return } - na := header.NDPNeighborAdvert(h.NDPPayload()) + // The remainder of payload must be only the neighbor advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + na := header.NDPNeighborAdvert(payload.ToView()) it, err := na.Options().Iter(true) if err != nil { // If we have a malformed NDP NA option, drop the packet. @@ -363,14 +379,15 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(packet, h) + copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ @@ -384,7 +401,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoReply: received.EchoReply.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } @@ -406,8 +423,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterAdvert: received.RouterAdvert.Increment() - p := h.NDPPayload() - if len(p) < header.NDPRAMinimumSize || !isNDPValid() { + // Is the NDP payload of sufficient size to hold a Router + // Advertisement? + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { received.Invalid.Increment() return } @@ -425,7 +443,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - ra := header.NDPRouterAdvert(p) + // The remainder of payload must be only the router advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + ra := header.NDPRouterAdvert(payload.ToView()) opts := ra.Options() // Are options valid as per the wire format? diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index bd099a7f8..d412ff688 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -166,7 +166,8 @@ func TestICMPCounts(t *testing.T) { }, { typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize}, + size: header.ICMPv6NeighborSolicitMinimumSize, + }, { typ: header.ICMPv6NeighborAdvert, size: header.ICMPv6NeighborAdvertMinimumSize, diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 331b0817b..486725131 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,7 +171,11 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index e9c652042..c7c663498 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -70,7 +70,10 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fwdTestNetHeaderLen) // Dispatch the packet to the transport protocol. @@ -473,7 +476,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -517,7 +520,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -564,7 +567,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -619,7 +622,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6c0a4b24d..6b91159d4 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -212,6 +212,11 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. +// // NOTE: unlike the Check API the returned map contains packets that should be // dropped. func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { @@ -226,7 +231,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa return drop } -// Precondition: pkt.NetworkHeader is set. +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -271,14 +278,21 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx return chainDrop } -// Precondition: pk.NetworkHeader is set. +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data.First(). + // pkt.Data. if pkt.NetworkHeader == nil { - pkt.NetworkHeader = pkt.Data.First() + var ok bool + pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + // Precondition has been violated. + panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) + } } // Check whether the packet matches the IP header filter. diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 7b4543caf..8be61f4b1 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -96,9 +96,12 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { newPkt := pkt.Clone() // Set network header. - headerView := newPkt.Data.First() + headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return RuleDrop, 0 + } netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] + newPkt.NetworkHeader = headerView hlen := int(netHeader.HeaderLength()) tlen := int(netHeader.TotalLength()) @@ -117,10 +120,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { udpHeader = header.UDP(newPkt.TransportHeader) } else { - if len(pkt.Data.First()) < header.UDPMinimumSize { + if pkt.Data.Size() < header.UDPMinimumSize { + return RuleDrop, 0 + } + hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) + if !ok { return RuleDrop, 0 } - udpHeader = header.UDP(newPkt.Data.First()) + udpHeader = header.UDP(hdr) } udpHeader.SetDestinationPort(rt.MinPort) case header.TCPProtocolNumber: @@ -128,10 +135,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { tcpHeader = header.TCP(newPkt.TransportHeader) } else { - if len(pkt.Data.First()) < header.TCPMinimumSize { + if pkt.Data.Size() < header.TCPMinimumSize { return RuleDrop, 0 } - tcpHeader = header.TCP(newPkt.TransportHeader) + hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) + if !ok { + return RuleDrop, 0 + } + tcpHeader = header.TCP(hdr) } // TODO(gvisor.dev/issue/170): Need to recompute checksum // and implement nat connection tracking to support TCP. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 016dbe15e..0c2b1f36a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1203,12 +1203,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - if len(pkt.Data.First()) < netProto.MinimumPacketSize() { + netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - - src, dst := netProto.ParseAddresses(pkt.Data.First()) + src, dst := netProto.ParseAddresses(netHeader) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1289,22 +1289,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - - firstData := pkt.Data.First() - pkt.Data.RemoveFirst() - - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { - pkt.Header = buffer.NewPrependableFromView(firstData) - } else { - firstDataLen := len(firstData) - - // pkt.Header should have enough capacity to hold n.linkEP's headers. - pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) - - // TODO(b/151227689): avoid copying the packet when forwarding - if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) - } + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { + pkt.Header = buffer.NewPrependable(linkHeaderLen) } if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { @@ -1332,12 +1318,13 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - if len(pkt.Data.First()) < transProto.MinimumPacketSize() { + transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1375,11 +1362,12 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - if len(pkt.Data.First()) < 8 { + transHeader, ok := pkt.Data.PullUp(8) + if !ok { return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index dc125f25e..e954a8b7e 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -37,7 +37,9 @@ type PacketBuffer struct { Data buffer.VectorisedView // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. + // down the stack, each layer adds to Header. Note that forwarded + // packets don't populate Headers on their way out -- their headers and + // payload are never parsed out and remain in Data. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index c7634ceb1..d45d2cc1f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -95,16 +95,18 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb := pkt.Data.First() - if len(nb) < fakeNetHeaderLen { + nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { return } - pkt.Data.TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 3084e6593..a611e44ab 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -642,10 +642,11 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.Header.View()[0]; dst != 3 { + hdrs := p.Pkt.Data.ToView() + if dst := hdrs[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.Header.View()[1]; src != 1 { + if src := hdrs[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index feef8dca0..b1d820372 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(pkt.Data.First()) - if h.Type() != header.ICMPv4EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(pkt.Data.First()) - if h.Type() != header.ICMPv6EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) + if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 40461fd31..7712ce652 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -144,7 +144,11 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h := header.TCP(s.data.First()) + h, ok := s.data.PullUp(header.TCPMinimumSize) + if !ok { + return false + } + hdr := header.TCP(h) // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: @@ -156,12 +160,16 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(h.DataOffset()) - if offset < header.TCPMinimumSize || offset > len(h) { + offset := int(hdr.DataOffset()) + if offset < header.TCPMinimumSize { + return false + } + hdrWithOpts, ok := s.data.PullUp(offset) + if !ok { return false } - s.options = []byte(h[header.TCPMinimumSize:offset]) + s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -173,18 +181,19 @@ func (s *segment) parse() bool { s.data.TrimFront(offset) } if verifyChecksum { - s.csum = h.Checksum() + hdr = header.TCP(hdrWithOpts) + s.csum = hdr.Checksum() xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = h.CalculateChecksum(xsum) + xsum = hdr.CalculateChecksum(xsum) s.data.TrimFront(offset) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(h.SequenceNumber()) - s.ackNumber = seqnum.Value(h.AckNumber()) - s.flags = h.Flags() - s.window = seqnum.Size(h.WindowSize()) + s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) + s.ackNumber = seqnum.Value(hdr.AckNumber()) + s.flags = hdr.Flags() + s.window = seqnum.Size(hdr.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index ab1014c7f..286c66cf5 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3548,7 +3548,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 c.SendSegment(vv) @@ -3575,7 +3575,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] // Overwrite a byte in the payload which should cause checksum // verification to fail. tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index edb54f0be..756ab913a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, - Port: hdr.SourcePort(), + Port: header.UDP(hdr).SourcePort(), }, } packet.data = pkt.Data diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 6e31a9bac..52af6de22 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -68,8 +68,13 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + h, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { + // Malformed packet. + r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + return true + } + if int(header.UDP(h).Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true -- cgit v1.2.3 From 120d3b50f4875824ec69f0cc39a09ac84fced35c Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Tue, 21 Apr 2020 07:15:25 -0700 Subject: Automated rollback of changelist 307477185 PiperOrigin-RevId: 307598974 --- pkg/sentry/socket/netfilter/tcp_matcher.go | 5 +- pkg/sentry/socket/netfilter/udp_matcher.go | 5 +- pkg/tcpip/buffer/view.go | 55 ++++---------- pkg/tcpip/buffer/view_test.go | 113 ----------------------------- pkg/tcpip/link/loopback/loopback.go | 10 +-- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 65 ++++------------- pkg/tcpip/network/arp/arp.go | 5 +- pkg/tcpip/network/ipv4/icmp.go | 20 ++--- pkg/tcpip/network/ipv4/ipv4.go | 12 +-- pkg/tcpip/network/ipv6/icmp.go | 74 +++++++------------ pkg/tcpip/network/ipv6/icmp_test.go | 3 +- pkg/tcpip/network/ipv6/ipv6.go | 6 +- pkg/tcpip/stack/forwarder_test.go | 13 ++-- pkg/tcpip/stack/iptables.go | 22 +----- pkg/tcpip/stack/iptables_targets.go | 23 ++---- pkg/tcpip/stack/nic.go | 34 ++++++--- pkg/tcpip/stack/packet_buffer.go | 4 +- pkg/tcpip/stack/stack_test.go | 10 +-- pkg/tcpip/stack/transport_test.go | 5 +- pkg/tcpip/transport/icmp/endpoint.go | 8 +- pkg/tcpip/transport/tcp/segment.go | 29 +++----- pkg/tcpip/transport/tcp/tcp_test.go | 4 +- pkg/tcpip/transport/udp/endpoint.go | 6 +- pkg/tcpip/transport/udp/protocol.go | 9 +-- 25 files changed, 147 insertions(+), 395 deletions(-) (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 55c0f04f3..ff1cfd8f6 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -121,13 +121,12 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa tcpHeader = header.TCP(pkt.TransportHeader) } else { // The TCP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.TCPMinimumSize { // There's no valid TCP header here, so we hotdrop the // packet. return false, true } - tcpHeader = header.TCP(hdr) + tcpHeader = header.TCP(pkt.Data.First()) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 04d03d494..3359418c1 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -120,13 +120,12 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa udpHeader = header.UDP(pkt.TransportHeader) } else { // The UDP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.UDPMinimumSize { // There's no valid UDP header here, so we hotdrop the // packet. return false, true } - udpHeader = header.UDP(hdr) + udpHeader = header.UDP(pkt.Data.First()) } // Check whether the source and destination ports are within the diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index f01217c91..8ec5d5d5c 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -77,8 +77,7 @@ func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } -// TrimFront removes the first "count" bytes of the vectorised view. It panics -// if count > vv.Size(). +// TrimFront removes the first "count" bytes of the vectorised view. func (vv *VectorisedView) TrimFront(count int) { for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { @@ -87,7 +86,7 @@ func (vv *VectorisedView) TrimFront(count int) { return } count -= len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } } @@ -105,7 +104,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) { count -= len(vv.views[0]) copy(v[copied:], vv.views[0]) copied += len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } if copied == 0 { return 0, io.EOF @@ -127,7 +126,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int count -= len(vv.views[0]) dstVV.AppendView(vv.views[0]) copied += len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } return copied } @@ -163,37 +162,22 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } -// PullUp returns the first "count" bytes of the vectorised view. If those -// bytes aren't already contiguous inside the vectorised view, PullUp will -// reallocate as needed to make them contiguous. PullUp fails and returns false -// when count > vv.Size(). -func (vv *VectorisedView) PullUp(count int) (View, bool) { +// First returns the first view of the vectorised view. +func (vv *VectorisedView) First() View { if len(vv.views) == 0 { - return nil, count == 0 - } - if count <= len(vv.views[0]) { - return vv.views[0][:count], true - } - if count > vv.size { - return nil, false + return nil } + return vv.views[0] +} - newFirst := NewView(count) - i := 0 - for offset := 0; offset < count; i++ { - copy(newFirst[offset:], vv.views[i]) - if count-offset < len(vv.views[i]) { - vv.views[i].TrimFront(count - offset) - break - } - offset += len(vv.views[i]) - vv.views[i] = nil +// RemoveFirst removes the first view of the vectorised view. +func (vv *VectorisedView) RemoveFirst() { + if len(vv.views) == 0 { + return } - // We're guaranteed that i > 0, since count is too large for the first - // view. - vv.views[i-1] = newFirst - vv.views = vv.views[i-1:] - return newFirst, true + vv.size -= len(vv.views[0]) + vv.views[0] = nil + vv.views = vv.views[1:] } // Size returns the size in bytes of the entire content stored in the vectorised view. @@ -241,10 +225,3 @@ func (vv *VectorisedView) Readers() []bytes.Reader { } return readers } - -// removeFirst panics when len(vv.views) < 1. -func (vv *VectorisedView) removeFirst() { - vv.size -= len(vv.views[0]) - vv.views[0] = nil - vv.views = vv.views[1:] -} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index c56795c7b..106e1994c 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -16,7 +16,6 @@ package buffer import ( - "bytes" "reflect" "testing" ) @@ -371,115 +370,3 @@ func TestVVRead(t *testing.T) { }) } } - -var pullUpTestCases = []struct { - comment string - in VectorisedView - count int - want []byte - result VectorisedView - ok bool -}{ - { - comment: "simple case", - in: vv(2, "12"), - count: 1, - want: []byte("1"), - result: vv(2, "12"), - ok: true, - }, - { - comment: "entire View", - in: vv(2, "1", "2"), - count: 1, - want: []byte("1"), - result: vv(2, "1", "2"), - ok: true, - }, - { - comment: "spanning across two Views", - in: vv(3, "1", "23"), - count: 2, - want: []byte("12"), - result: vv(3, "12", "3"), - ok: true, - }, - { - comment: "spanning across all Views", - in: vv(5, "1", "23", "45"), - count: 5, - want: []byte("12345"), - result: vv(5, "12345"), - ok: true, - }, - { - comment: "count = 0", - in: vv(1, "1"), - count: 0, - want: []byte{}, - result: vv(1, "1"), - ok: true, - }, - { - comment: "count = size", - in: vv(1, "1"), - count: 1, - want: []byte("1"), - result: vv(1, "1"), - ok: true, - }, - { - comment: "count too large", - in: vv(3, "1", "23"), - count: 4, - want: nil, - result: vv(3, "1", "23"), - ok: false, - }, - { - comment: "empty vv", - in: vv(0, ""), - count: 1, - want: nil, - result: vv(0, ""), - ok: false, - }, - { - comment: "empty vv, count = 0", - in: vv(0, ""), - count: 0, - want: nil, - result: vv(0, ""), - ok: true, - }, - { - comment: "empty views", - in: vv(3, "", "1", "", "23"), - count: 2, - want: []byte("12"), - result: vv(3, "12", "3"), - ok: true, - }, -} - -func TestPullUp(t *testing.T) { - for _, c := range pullUpTestCases { - got, ok := c.in.PullUp(c.count) - - // Is the return value right? - if ok != c.ok { - t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", - c.comment, c.count, c.in, ok, c.ok) - } - if bytes.Compare(got, View(c.want)) != 0 { - t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", - c.comment, c.count, c.in, got, c.want) - } - - // Is the underlying structure right? - if !reflect.DeepEqual(c.in, c.result) { - t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", - c.comment, c.count, c.in, c.result) - } - } -} diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 073c84ef9..1e2255bfa 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // There should be an ethernet header at the beginning of vv. - hdr, ok := vv.PullUp(header.EthernetMinimumSize) - if !ok { - // Reject the packet if it's shorter than an ethernet header. + // Reject the packet if it's shorter than an ethernet header. + if vv.Size() < header.EthernetMinimumSize { return tcpip.ErrBadAddress } - linkHeader := header.Ethernet(hdr) + + // There should be an ethernet header at the beginning of vv. + linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 33f640b85..27ea3f531 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.ToView()) + rcvd := []byte(c.packets[0].vv.First()) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 0799c8f4d..be2537a82 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -171,7 +171,11 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - logPacket(prefix, protocol, pkt, gso) + first := pkt.Header.View() + if len(first) == 0 { + first = pkt.Data.First() + } + logPacket(prefix, protocol, first, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Header.UsedLength() + pkt.Data.Size() @@ -234,7 +238,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() } -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -243,49 +247,28 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P size := uint16(0) var fragmentOffset uint16 var moreFragments bool - - // Create a clone of pkt, including any headers if present. Avoid allocating - // backing memory for the clone. - views := [8]buffer.View{} - vv := buffer.NewVectorisedView(0, views[:0]) - vv.AppendView(pkt.Header.View()) - vv.Append(pkt.Data) - switch protocol { case header.IPv4ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - ipv4 := header.IPv4(hdr) + ipv4 := header.IPv4(b) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - vv.TrimFront(int(ipv4.HeaderLength())) + b = b[ipv4.HeaderLength():] id = int(ipv4.ID()) case header.IPv6ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv6MinimumSize) - if !ok { - return - } - ipv6 := header.IPv6(hdr) + ipv6 := header.IPv6(b) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() transProto = ipv6.NextHeader() size = ipv6.PayloadLength() - vv.TrimFront(header.IPv6MinimumSize) + b = b[header.IPv6MinimumSize:] case header.ARPProtocolNumber: - hdr, ok := vv.PullUp(header.ARPSize) - if !ok { - return - } - vv.TrimFront(header.ARPSize) - arp := header.ARP(hdr) + arp := header.ARP(b) log.Infof( "%s arp %v (%v) -> %v (%v) valid:%v", prefix, @@ -301,7 +284,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P // We aren't guaranteed to have a transport header - it's possible for // writes via raw endpoints to contain only network headers. - if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize { + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) return } @@ -314,11 +297,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) - if !ok { - break - } - icmp := header.ICMPv4(hdr) + icmp := header.ICMPv4(b) icmpType := "unknown" if fragmentOffset == 0 { switch icmp.Type() { @@ -351,11 +330,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.ICMPv6ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) - if !ok { - break - } - icmp := header.ICMPv6(hdr) + icmp := header.ICMPv6(b) icmpType := "unknown" switch icmp.Type() { case header.ICMPv6DstUnreachable: @@ -386,11 +361,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.UDPProtocolNumber: transName = "udp" - hdr, ok := vv.PullUp(header.UDPMinimumSize) - if !ok { - break - } - udp := header.UDP(hdr) + udp := header.UDP(b) if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -400,11 +371,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.TCPProtocolNumber: transName = "tcp" - hdr, ok := vv.PullUp(header.TCPMinimumSize) - if !ok { - break - } - tcp := header.TCP(hdr) + tcp := header.TCP(b) if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index cf73a939e..7acbfa0a8 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -93,10 +93,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf } func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v, ok := pkt.Data.PullUp(header.ARPSize) - if !ok { - return - } + v := pkt.Data.First() h := header.ARP(v) if !h.IsValid() { return diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 4cbefe5ab..c4bf1ba5c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -25,11 +25,7 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - hdr := header.IPv4(h) + h := header.IPv4(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -38,12 +34,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { return } - hlen := int(hdr.HeaderLength()) - if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { + hlen := int(h.HeaderLength()) + if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -52,15 +48,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) - p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + p := h.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok { + v := pkt.Data.First() + if len(v) < header.ICMPv4MinimumSize { received.Invalid.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 17202cc7a..104aafbed 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -328,11 +328,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return tcpip.ErrInvalidOptionValue - } - ip := header.IPv4(h) + ip := header.IPv4(pkt.Data.First()) if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } @@ -382,11 +378,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } + headerView := pkt.Data.First() h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index bdf3a0d25..b68983d10 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -28,11 +28,7 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { - return - } - hdr := header.IPv6(h) + h := header.IPv6(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -40,21 +36,17 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { return } // Skip the IP header, then handle the fragmentation header if there // is one. pkt.Data.TrimFront(header.IPv6MinimumSize) - p := hdr.TransportProtocol() + p := h.TransportProtocol() if p == header.IPv6FragmentHeader { - f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) - if !ok { - return - } - fragHdr := header.IPv6Fragment(f) - if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { + f := header.IPv6Fragment(pkt.Data.First()) + if !f.IsValid() || f.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. return @@ -63,19 +55,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip fragmentation header and find out the actual protocol // number. pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) - p = fragHdr.TransportProtocol() + p = f.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) - if !ok { + v := pkt.Data.First() + if len(v) < header.ICMPv6MinimumSize { received.Invalid.Increment() return } @@ -84,9 +76,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Validate ICMPv6 checksum before processing the packet. // + // Only the first view in vv is accounted for by h. To account for the + // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) - payload.TrimFront(len(h)) + payload.RemoveFirst() if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return @@ -107,40 +101,34 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch h.Type() { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() - hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) - if !ok { + if len(v) < header.ICMPv6PacketTooBigMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := header.ICMPv6(hdr).MTU() + mtu := h.MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() - hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) - if !ok { + if len(v) < header.ICMPv6DstUnreachableMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch header.ICMPv6(hdr).Code() { + switch h.Code() { case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { received.Invalid.Increment() return } - // The remainder of payload must be only the neighbor solicitation, so - // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, - // NDP messages cannot be fragmented. Also note that in the common case NDP - // datagrams are very small and ToView() will not incur allocations. - ns := header.NDPNeighborSolicit(payload.ToView()) + ns := header.NDPNeighborSolicit(h.NDPPayload()) it, err := ns.Options().Iter(true) if err != nil { // If we have a malformed NDP NS option, drop the packet. @@ -298,16 +286,12 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() { received.Invalid.Increment() return } - // The remainder of payload must be only the neighbor advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. - na := header.NDPNeighborAdvert(payload.ToView()) + na := header.NDPNeighborAdvert(h.NDPPayload()) it, err := na.Options().Iter(true) if err != nil { // If we have a malformed NDP NA option, drop the packet. @@ -379,15 +363,14 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) - if !ok { + if len(v) < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(packet, icmpHdr) + copy(packet, h) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ @@ -401,7 +384,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoReply: received.EchoReply.Increment() - if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { + if len(v) < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } @@ -423,9 +406,8 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterAdvert: received.RouterAdvert.Increment() - // Is the NDP payload of sufficient size to hold a Router - // Advertisement? - if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { + p := h.NDPPayload() + if len(p) < header.NDPRAMinimumSize || !isNDPValid() { received.Invalid.Increment() return } @@ -443,11 +425,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - // The remainder of payload must be only the router advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. - ra := header.NDPRouterAdvert(payload.ToView()) + ra := header.NDPRouterAdvert(p) opts := ra.Options() // Are options valid as per the wire format? diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d412ff688..bd099a7f8 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -166,8 +166,7 @@ func TestICMPCounts(t *testing.T) { }, { typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - }, + size: header.ICMPv6NeighborSolicitMinimumSize}, { typ: header.ICMPv6NeighborAdvert, size: header.ICMPv6NeighborAdvertMinimumSize, diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 486725131..331b0817b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,11 +171,7 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } + headerView := pkt.Data.First() h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index c7c663498..e9c652042 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -70,10 +70,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. - b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) - if !ok { - return - } + b := pkt.Data.First() pkt.Data.TrimFront(fwdTestNetHeaderLen) // Dispatch the packet to the transport protocol. @@ -476,7 +473,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -520,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -567,7 +564,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -622,7 +619,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6b91159d4..6c0a4b24d 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -212,11 +212,6 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. -// // NOTE: unlike the Check API the returned map contains packets that should be // dropped. func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { @@ -231,9 +226,7 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa return drop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Precondition: pkt.NetworkHeader is set. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -278,21 +271,14 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx return chainDrop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Precondition: pk.NetworkHeader is set. func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data. + // pkt.Data.First(). if pkt.NetworkHeader == nil { - var ok bool - pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // Precondition has been violated. - panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) - } + pkt.NetworkHeader = pkt.Data.First() } // Check whether the packet matches the IP header filter. diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 8be61f4b1..7b4543caf 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -96,12 +96,9 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { newPkt := pkt.Clone() // Set network header. - headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return RuleDrop, 0 - } + headerView := newPkt.Data.First() netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView + newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] hlen := int(netHeader.HeaderLength()) tlen := int(netHeader.TotalLength()) @@ -120,14 +117,10 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { udpHeader = header.UDP(newPkt.TransportHeader) } else { - if pkt.Data.Size() < header.UDPMinimumSize { - return RuleDrop, 0 - } - hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.UDPMinimumSize { return RuleDrop, 0 } - udpHeader = header.UDP(hdr) + udpHeader = header.UDP(newPkt.Data.First()) } udpHeader.SetDestinationPort(rt.MinPort) case header.TCPProtocolNumber: @@ -135,14 +128,10 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { tcpHeader = header.TCP(newPkt.TransportHeader) } else { - if pkt.Data.Size() < header.TCPMinimumSize { + if len(pkt.Data.First()) < header.TCPMinimumSize { return RuleDrop, 0 } - hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) - if !ok { - return RuleDrop, 0 - } - tcpHeader = header.TCP(hdr) + tcpHeader = header.TCP(newPkt.TransportHeader) } // TODO(gvisor.dev/issue/170): Need to recompute checksum // and implement nat connection tracking to support TCP. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 0c2b1f36a..016dbe15e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1203,12 +1203,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) - if !ok { + if len(pkt.Data.First()) < netProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - src, dst := netProto.ParseAddresses(netHeader) + + src, dst := netProto.ParseAddresses(pkt.Data.First()) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1289,8 +1289,22 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { - pkt.Header = buffer.NewPrependable(linkHeaderLen) + + firstData := pkt.Data.First() + pkt.Data.RemoveFirst() + + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { + pkt.Header = buffer.NewPrependableFromView(firstData) + } else { + firstDataLen := len(firstData) + + // pkt.Header should have enough capacity to hold n.linkEP's headers. + pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) + + // TODO(b/151227689): avoid copying the packet when forwarding + if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) + } } if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { @@ -1318,13 +1332,12 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) - if !ok { + if len(pkt.Data.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1362,12 +1375,11 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - transHeader, ok := pkt.Data.PullUp(8) - if !ok { + if len(pkt.Data.First()) < 8 { return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index e954a8b7e..dc125f25e 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -37,9 +37,7 @@ type PacketBuffer struct { Data buffer.VectorisedView // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. Note that forwarded - // packets don't populate Headers on their way out -- their headers and - // payload are never parsed out and remain in Data. + // down the stack, each layer adds to Header. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index d45d2cc1f..c7634ceb1 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -95,18 +95,16 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { - return - } + b := pkt.Data.First() pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { + nb := pkt.Data.First() + if len(nb) < fakeNetHeaderLen { return } + pkt.Data.TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a611e44ab..3084e6593 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -642,11 +642,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - hdrs := p.Pkt.Data.ToView() - if dst := hdrs[0]; dst != 3 { + if dst := p.Pkt.Header.View()[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := hdrs[1]; src != 1 { + if src := p.Pkt.Header.View()[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index b1d820372..feef8dca0 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { + h := header.ICMPv4(pkt.Data.First()) + if h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) - if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { + h := header.ICMPv6(pkt.Data.First()) + if h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 7712ce652..40461fd31 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -144,11 +144,7 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h, ok := s.data.PullUp(header.TCPMinimumSize) - if !ok { - return false - } - hdr := header.TCP(h) + h := header.TCP(s.data.First()) // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: @@ -160,16 +156,12 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(hdr.DataOffset()) - if offset < header.TCPMinimumSize { - return false - } - hdrWithOpts, ok := s.data.PullUp(offset) - if !ok { + offset := int(h.DataOffset()) + if offset < header.TCPMinimumSize || offset > len(h) { return false } - s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) + s.options = []byte(h[header.TCPMinimumSize:offset]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -181,19 +173,18 @@ func (s *segment) parse() bool { s.data.TrimFront(offset) } if verifyChecksum { - hdr = header.TCP(hdrWithOpts) - s.csum = hdr.Checksum() + s.csum = h.Checksum() xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = hdr.CalculateChecksum(xsum) + xsum = h.CalculateChecksum(xsum) s.data.TrimFront(offset) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) - s.ackNumber = seqnum.Value(hdr.AckNumber()) - s.flags = hdr.Flags() - s.window = seqnum.Size(hdr.WindowSize()) + s.sequenceNumber = seqnum.Value(h.SequenceNumber()) + s.ackNumber = seqnum.Value(h.AckNumber()) + s.flags = h.Flags() + s.window = seqnum.Size(h.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 286c66cf5..ab1014c7f 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3548,7 +3548,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] + tcpbuf := vv.First()[header.IPv4MinimumSize:] tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 c.SendSegment(vv) @@ -3575,7 +3575,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] + tcpbuf := vv.First()[header.IPv4MinimumSize:] // Overwrite a byte in the payload which should cause checksum // verification to fail. tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 756ab913a..edb54f0be 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, - Port: header.UDP(hdr).SourcePort(), + Port: hdr.SourcePort(), }, } packet.data = pkt.Data diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 52af6de22..6e31a9bac 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -68,13 +68,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. - h, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { - // Malformed packet. - r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() - return true - } - if int(header.UDP(h).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true -- cgit v1.2.3 From eccae0f77d3708d591119488f427eca90de7c711 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 23 Apr 2020 17:27:24 -0700 Subject: Remove View.First() and View.RemoveFirst() These methods let users eaily break the VectorisedView abstraction, and allowed netstack to slip into pseudo-enforcement of the "all headers are in the first View" invariant. Removing them and replacing with PullUp(n) breaks this reliance and will make it easier to add iptables support and rework network buffer management. The new View.PullUp(n) method is low cost in the common case, when when all the headers fit in the first View. PiperOrigin-RevId: 308163542 --- pkg/sentry/socket/netfilter/tcp_matcher.go | 5 +- pkg/sentry/socket/netfilter/udp_matcher.go | 5 +- pkg/tcpip/buffer/view.go | 55 ++++++++++---- pkg/tcpip/buffer/view_test.go | 113 +++++++++++++++++++++++++++++ pkg/tcpip/link/loopback/loopback.go | 10 +-- pkg/tcpip/link/rawfile/BUILD | 9 ++- pkg/tcpip/link/rawfile/rawfile_test.go | 46 ++++++++++++ pkg/tcpip/link/rawfile/rawfile_unsafe.go | 6 +- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 65 +++++++++++++---- pkg/tcpip/network/arp/arp.go | 5 +- pkg/tcpip/network/ipv4/icmp.go | 20 +++-- pkg/tcpip/network/ipv4/ipv4.go | 12 ++- pkg/tcpip/network/ipv6/icmp.go | 74 ++++++++++++------- pkg/tcpip/network/ipv6/icmp_test.go | 3 +- pkg/tcpip/network/ipv6/ipv6.go | 6 +- pkg/tcpip/stack/forwarder_test.go | 13 ++-- pkg/tcpip/stack/iptables.go | 22 +++++- pkg/tcpip/stack/iptables_targets.go | 23 ++++-- pkg/tcpip/stack/nic.go | 34 +++------ pkg/tcpip/stack/packet_buffer.go | 8 +- pkg/tcpip/stack/stack_test.go | 10 ++- pkg/tcpip/stack/transport_test.go | 5 +- pkg/tcpip/transport/icmp/endpoint.go | 8 +- pkg/tcpip/transport/tcp/segment.go | 29 +++++--- pkg/tcpip/transport/tcp/tcp_test.go | 4 +- pkg/tcpip/transport/udp/endpoint.go | 6 +- pkg/tcpip/transport/udp/protocol.go | 9 ++- 28 files changed, 458 insertions(+), 149 deletions(-) create mode 100644 pkg/tcpip/link/rawfile/rawfile_test.go (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index ff1cfd8f6..55c0f04f3 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -121,12 +121,13 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa tcpHeader = header.TCP(pkt.TransportHeader) } else { // The TCP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.TCPMinimumSize { + hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + if !ok { // There's no valid TCP header here, so we hotdrop the // packet. return false, true } - tcpHeader = header.TCP(pkt.Data.First()) + tcpHeader = header.TCP(hdr) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 3359418c1..04d03d494 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -120,12 +120,13 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa udpHeader = header.UDP(pkt.TransportHeader) } else { // The UDP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.UDPMinimumSize { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { // There's no valid UDP header here, so we hotdrop the // packet. return false, true } - udpHeader = header.UDP(pkt.Data.First()) + udpHeader = header.UDP(hdr) } // Check whether the source and destination ports are within the diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 8ec5d5d5c..f01217c91 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -77,7 +77,8 @@ func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } -// TrimFront removes the first "count" bytes of the vectorised view. +// TrimFront removes the first "count" bytes of the vectorised view. It panics +// if count > vv.Size(). func (vv *VectorisedView) TrimFront(count int) { for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { @@ -86,7 +87,7 @@ func (vv *VectorisedView) TrimFront(count int) { return } count -= len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } } @@ -104,7 +105,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) { count -= len(vv.views[0]) copy(v[copied:], vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } if copied == 0 { return 0, io.EOF @@ -126,7 +127,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int count -= len(vv.views[0]) dstVV.AppendView(vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } return copied } @@ -162,22 +163,37 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } -// First returns the first view of the vectorised view. -func (vv *VectorisedView) First() View { +// PullUp returns the first "count" bytes of the vectorised view. If those +// bytes aren't already contiguous inside the vectorised view, PullUp will +// reallocate as needed to make them contiguous. PullUp fails and returns false +// when count > vv.Size(). +func (vv *VectorisedView) PullUp(count int) (View, bool) { if len(vv.views) == 0 { - return nil + return nil, count == 0 + } + if count <= len(vv.views[0]) { + return vv.views[0][:count], true + } + if count > vv.size { + return nil, false } - return vv.views[0] -} -// RemoveFirst removes the first view of the vectorised view. -func (vv *VectorisedView) RemoveFirst() { - if len(vv.views) == 0 { - return + newFirst := NewView(count) + i := 0 + for offset := 0; offset < count; i++ { + copy(newFirst[offset:], vv.views[i]) + if count-offset < len(vv.views[i]) { + vv.views[i].TrimFront(count - offset) + break + } + offset += len(vv.views[i]) + vv.views[i] = nil } - vv.size -= len(vv.views[0]) - vv.views[0] = nil - vv.views = vv.views[1:] + // We're guaranteed that i > 0, since count is too large for the first + // view. + vv.views[i-1] = newFirst + vv.views = vv.views[i-1:] + return newFirst, true } // Size returns the size in bytes of the entire content stored in the vectorised view. @@ -225,3 +241,10 @@ func (vv *VectorisedView) Readers() []bytes.Reader { } return readers } + +// removeFirst panics when len(vv.views) < 1. +func (vv *VectorisedView) removeFirst() { + vv.size -= len(vv.views[0]) + vv.views[0] = nil + vv.views = vv.views[1:] +} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index 106e1994c..c56795c7b 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -16,6 +16,7 @@ package buffer import ( + "bytes" "reflect" "testing" ) @@ -370,3 +371,115 @@ func TestVVRead(t *testing.T) { }) } } + +var pullUpTestCases = []struct { + comment string + in VectorisedView + count int + want []byte + result VectorisedView + ok bool +}{ + { + comment: "simple case", + in: vv(2, "12"), + count: 1, + want: []byte("1"), + result: vv(2, "12"), + ok: true, + }, + { + comment: "entire View", + in: vv(2, "1", "2"), + count: 1, + want: []byte("1"), + result: vv(2, "1", "2"), + ok: true, + }, + { + comment: "spanning across two Views", + in: vv(3, "1", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, + { + comment: "spanning across all Views", + in: vv(5, "1", "23", "45"), + count: 5, + want: []byte("12345"), + result: vv(5, "12345"), + ok: true, + }, + { + comment: "count = 0", + in: vv(1, "1"), + count: 0, + want: []byte{}, + result: vv(1, "1"), + ok: true, + }, + { + comment: "count = size", + in: vv(1, "1"), + count: 1, + want: []byte("1"), + result: vv(1, "1"), + ok: true, + }, + { + comment: "count too large", + in: vv(3, "1", "23"), + count: 4, + want: nil, + result: vv(3, "1", "23"), + ok: false, + }, + { + comment: "empty vv", + in: vv(0, ""), + count: 1, + want: nil, + result: vv(0, ""), + ok: false, + }, + { + comment: "empty vv, count = 0", + in: vv(0, ""), + count: 0, + want: nil, + result: vv(0, ""), + ok: true, + }, + { + comment: "empty views", + in: vv(3, "", "1", "", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, +} + +func TestPullUp(t *testing.T) { + for _, c := range pullUpTestCases { + got, ok := c.in.PullUp(c.count) + + // Is the return value right? + if ok != c.ok { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", + c.comment, c.count, c.in, ok, c.ok) + } + if bytes.Compare(got, View(c.want)) != 0 { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", + c.comment, c.count, c.in, got, c.want) + } + + // Is the underlying structure right? + if !reflect.DeepEqual(c.in, c.result) { + t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", + c.comment, c.count, c.in, c.result) + } + } +} diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 1e2255bfa..073c84ef9 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // Reject the packet if it's shorter than an ethernet header. - if vv.Size() < header.EthernetMinimumSize { + // There should be an ethernet header at the beginning of vv. + hdr, ok := vv.PullUp(header.EthernetMinimumSize) + if !ok { + // Reject the packet if it's shorter than an ethernet header. return tcpip.ErrBadAddress } - - // There should be an ethernet header at the beginning of vv. - linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) + linkHeader := header.Ethernet(hdr) vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 14b527bc2..9cc08d0e2 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -18,3 +18,10 @@ go_library( "@org_golang_x_sys//unix:go_default_library", ], ) + +go_test( + name = "rawfile_test", + size = "small", + srcs = ["rawfile_test.go"], + library = ":rawfile", +) diff --git a/pkg/tcpip/link/rawfile/rawfile_test.go b/pkg/tcpip/link/rawfile/rawfile_test.go new file mode 100644 index 000000000..8f14ba761 --- /dev/null +++ b/pkg/tcpip/link/rawfile/rawfile_test.go @@ -0,0 +1,46 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package rawfile + +import ( + "syscall" + "testing" +) + +func TestNonBlockingWrite3ZeroLength(t *testing.T) { + fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0) + if err != nil { + t.Fatalf("failed to open /dev/null: %v", err) + } + defer syscall.Close(fd) + + if err := NonBlockingWrite3(fd, []byte{}, []byte{0}, nil); err != nil { + t.Fatalf("failed to write: %v", err) + } +} + +func TestNonBlockingWrite3Nil(t *testing.T) { + fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0) + if err != nil { + t.Fatalf("failed to open /dev/null: %v", err) + } + defer syscall.Close(fd) + + if err := NonBlockingWrite3(fd, nil, []byte{0}, nil); err != nil { + t.Fatalf("failed to write: %v", err) + } +} diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 44e25d475..92efd0bf8 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -76,9 +76,13 @@ func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error { // We have two buffers. Build the iovec that represents them and issue // a writev syscall. + var base *byte + if len(b1) > 0 { + base = &b1[0] + } iovec := [3]syscall.Iovec{ { - Base: &b1[0], + Base: base, Len: uint64(len(b1)), }, { diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 27ea3f531..33f640b85 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.First()) + rcvd := []byte(c.packets[0].vv.ToView()) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index be2537a82..0799c8f4d 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -171,11 +171,7 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - first := pkt.Header.View() - if len(first) == 0 { - first = pkt.Data.First() - } - logPacket(prefix, protocol, first, gso) + logPacket(prefix, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Header.UsedLength() + pkt.Data.Size() @@ -238,7 +234,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() } -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -247,28 +243,49 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie size := uint16(0) var fragmentOffset uint16 var moreFragments bool + + // Create a clone of pkt, including any headers if present. Avoid allocating + // backing memory for the clone. + views := [8]buffer.View{} + vv := buffer.NewVectorisedView(0, views[:0]) + vv.AppendView(pkt.Header.View()) + vv.Append(pkt.Data) + switch protocol { case header.IPv4ProtocolNumber: - ipv4 := header.IPv4(b) + hdr, ok := vv.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + ipv4 := header.IPv4(hdr) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - b = b[ipv4.HeaderLength():] + vv.TrimFront(int(ipv4.HeaderLength())) id = int(ipv4.ID()) case header.IPv6ProtocolNumber: - ipv6 := header.IPv6(b) + hdr, ok := vv.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + ipv6 := header.IPv6(hdr) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() transProto = ipv6.NextHeader() size = ipv6.PayloadLength() - b = b[header.IPv6MinimumSize:] + vv.TrimFront(header.IPv6MinimumSize) case header.ARPProtocolNumber: - arp := header.ARP(b) + hdr, ok := vv.PullUp(header.ARPSize) + if !ok { + return + } + vv.TrimFront(header.ARPSize) + arp := header.ARP(hdr) log.Infof( "%s arp %v (%v) -> %v (%v) valid:%v", prefix, @@ -284,7 +301,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie // We aren't guaranteed to have a transport header - it's possible for // writes via raw endpoints to contain only network headers. - if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize { log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) return } @@ -297,7 +314,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - icmp := header.ICMPv4(b) + hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) + if !ok { + break + } + icmp := header.ICMPv4(hdr) icmpType := "unknown" if fragmentOffset == 0 { switch icmp.Type() { @@ -330,7 +351,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.ICMPv6ProtocolNumber: transName = "icmp" - icmp := header.ICMPv6(b) + hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) + if !ok { + break + } + icmp := header.ICMPv6(hdr) icmpType := "unknown" switch icmp.Type() { case header.ICMPv6DstUnreachable: @@ -361,7 +386,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.UDPProtocolNumber: transName = "udp" - udp := header.UDP(b) + hdr, ok := vv.PullUp(header.UDPMinimumSize) + if !ok { + break + } + udp := header.UDP(hdr) if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -371,7 +400,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.TCPProtocolNumber: transName = "tcp" - tcp := header.TCP(b) + hdr, ok := vv.PullUp(header.TCPMinimumSize) + if !ok { + break + } + tcp := header.TCP(hdr) if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7acbfa0a8..cf73a939e 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -93,7 +93,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf } func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v := pkt.Data.First() + v, ok := pkt.Data.PullUp(header.ARPSize) + if !ok { + return + } h := header.ARP(v) if !h.IsValid() { return diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index c4bf1ba5c..4cbefe5ab 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -25,7 +25,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + hdr := header.IPv4(h) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -34,12 +38,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } - hlen := int(h.HeaderLength()) - if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { + hlen := int(hdr.HeaderLength()) + if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -48,15 +52,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) - p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + p := hdr.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv4MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok { received.Invalid.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 104aafbed..17202cc7a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -328,7 +328,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - ip := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return tcpip.ErrInvalidOptionValue + } + ip := header.IPv4(h) if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } @@ -378,7 +382,11 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index b68983d10..bdf3a0d25 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -28,7 +28,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv6(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + hdr := header.IPv6(h) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -36,17 +40,21 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } // Skip the IP header, then handle the fragmentation header if there // is one. pkt.Data.TrimFront(header.IPv6MinimumSize) - p := h.TransportProtocol() + p := hdr.TransportProtocol() if p == header.IPv6FragmentHeader { - f := header.IPv6Fragment(pkt.Data.First()) - if !f.IsValid() || f.FragmentOffset() != 0 { + f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) + if !ok { + return + } + fragHdr := header.IPv6Fragment(f) + if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. return @@ -55,19 +63,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip fragmentation header and find out the actual protocol // number. pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) - p = f.TransportProtocol() + p = fragHdr.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv6MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) + if !ok { received.Invalid.Increment() return } @@ -76,11 +84,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Validate ICMPv6 checksum before processing the packet. // - // Only the first view in vv is accounted for by h. To account for the - // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) - payload.RemoveFirst() + payload.TrimFront(len(h)) if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return @@ -101,34 +107,40 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch h.Type() { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() - if len(v) < header.ICMPv6PacketTooBigMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := h.MTU() + mtu := header.ICMPv6(hdr).MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() - if len(v) < header.ICMPv6DstUnreachableMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch h.Code() { + switch header.ICMPv6(hdr).Code() { case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { received.Invalid.Increment() return } - ns := header.NDPNeighborSolicit(h.NDPPayload()) + // The remainder of payload must be only the neighbor solicitation, so + // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, + // NDP messages cannot be fragmented. Also note that in the common case NDP + // datagrams are very small and ToView() will not incur allocations. + ns := header.NDPNeighborSolicit(payload.ToView()) it, err := ns.Options().Iter(true) if err != nil { // If we have a malformed NDP NS option, drop the packet. @@ -286,12 +298,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { received.Invalid.Increment() return } - na := header.NDPNeighborAdvert(h.NDPPayload()) + // The remainder of payload must be only the neighbor advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + na := header.NDPNeighborAdvert(payload.ToView()) it, err := na.Options().Iter(true) if err != nil { // If we have a malformed NDP NA option, drop the packet. @@ -363,14 +379,15 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(packet, h) + copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ @@ -384,7 +401,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoReply: received.EchoReply.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } @@ -406,8 +423,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterAdvert: received.RouterAdvert.Increment() - p := h.NDPPayload() - if len(p) < header.NDPRAMinimumSize || !isNDPValid() { + // Is the NDP payload of sufficient size to hold a Router + // Advertisement? + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { received.Invalid.Increment() return } @@ -425,7 +443,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - ra := header.NDPRouterAdvert(p) + // The remainder of payload must be only the router advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + ra := header.NDPRouterAdvert(payload.ToView()) opts := ra.Options() // Are options valid as per the wire format? diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index bd099a7f8..d412ff688 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -166,7 +166,8 @@ func TestICMPCounts(t *testing.T) { }, { typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize}, + size: header.ICMPv6NeighborSolicitMinimumSize, + }, { typ: header.ICMPv6NeighborAdvert, size: header.ICMPv6NeighborAdvertMinimumSize, diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 331b0817b..486725131 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,7 +171,11 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index e9c652042..c7c663498 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -70,7 +70,10 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fwdTestNetHeaderLen) // Dispatch the packet to the transport protocol. @@ -473,7 +476,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -517,7 +520,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -564,7 +567,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -619,7 +622,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6c0a4b24d..6b91159d4 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -212,6 +212,11 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. +// // NOTE: unlike the Check API the returned map contains packets that should be // dropped. func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { @@ -226,7 +231,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa return drop } -// Precondition: pkt.NetworkHeader is set. +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -271,14 +278,21 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx return chainDrop } -// Precondition: pk.NetworkHeader is set. +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data.First(). + // pkt.Data. if pkt.NetworkHeader == nil { - pkt.NetworkHeader = pkt.Data.First() + var ok bool + pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + // Precondition has been violated. + panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) + } } // Check whether the packet matches the IP header filter. diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 7b4543caf..8be61f4b1 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -96,9 +96,12 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { newPkt := pkt.Clone() // Set network header. - headerView := newPkt.Data.First() + headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return RuleDrop, 0 + } netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] + newPkt.NetworkHeader = headerView hlen := int(netHeader.HeaderLength()) tlen := int(netHeader.TotalLength()) @@ -117,10 +120,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { udpHeader = header.UDP(newPkt.TransportHeader) } else { - if len(pkt.Data.First()) < header.UDPMinimumSize { + if pkt.Data.Size() < header.UDPMinimumSize { + return RuleDrop, 0 + } + hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) + if !ok { return RuleDrop, 0 } - udpHeader = header.UDP(newPkt.Data.First()) + udpHeader = header.UDP(hdr) } udpHeader.SetDestinationPort(rt.MinPort) case header.TCPProtocolNumber: @@ -128,10 +135,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { tcpHeader = header.TCP(newPkt.TransportHeader) } else { - if len(pkt.Data.First()) < header.TCPMinimumSize { + if pkt.Data.Size() < header.TCPMinimumSize { return RuleDrop, 0 } - tcpHeader = header.TCP(newPkt.TransportHeader) + hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) + if !ok { + return RuleDrop, 0 + } + tcpHeader = header.TCP(hdr) } // TODO(gvisor.dev/issue/170): Need to recompute checksum // and implement nat connection tracking to support TCP. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 016dbe15e..0c2b1f36a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1203,12 +1203,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - if len(pkt.Data.First()) < netProto.MinimumPacketSize() { + netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - - src, dst := netProto.ParseAddresses(pkt.Data.First()) + src, dst := netProto.ParseAddresses(netHeader) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1289,22 +1289,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - - firstData := pkt.Data.First() - pkt.Data.RemoveFirst() - - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { - pkt.Header = buffer.NewPrependableFromView(firstData) - } else { - firstDataLen := len(firstData) - - // pkt.Header should have enough capacity to hold n.linkEP's headers. - pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) - - // TODO(b/151227689): avoid copying the packet when forwarding - if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) - } + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { + pkt.Header = buffer.NewPrependable(linkHeaderLen) } if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { @@ -1332,12 +1318,13 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - if len(pkt.Data.First()) < transProto.MinimumPacketSize() { + transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1375,11 +1362,12 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - if len(pkt.Data.First()) < 8 { + transHeader, ok := pkt.Data.PullUp(8) + if !ok { return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index dc125f25e..7d36f8e84 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -37,7 +37,13 @@ type PacketBuffer struct { Data buffer.VectorisedView // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. + // down the stack, each layer adds to Header. Note that forwarded + // packets don't populate Headers on their way out -- their headers and + // payload are never parsed out and remain in Data. + // + // TODO(gvisor.dev/issue/170): Forwarded packets don't currently + // populate Header, but should. This will be doable once early parsing + // (https://github.com/google/gvisor/pull/1995) is supported. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index c7634ceb1..d45d2cc1f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -95,16 +95,18 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb := pkt.Data.First() - if len(nb) < fakeNetHeaderLen { + nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { return } - pkt.Data.TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 3084e6593..a611e44ab 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -642,10 +642,11 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.Header.View()[0]; dst != 3 { + hdrs := p.Pkt.Data.ToView() + if dst := hdrs[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.Header.View()[1]; src != 1 { + if src := hdrs[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index feef8dca0..b1d820372 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(pkt.Data.First()) - if h.Type() != header.ICMPv4EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(pkt.Data.First()) - if h.Type() != header.ICMPv6EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) + if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 40461fd31..7712ce652 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -144,7 +144,11 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h := header.TCP(s.data.First()) + h, ok := s.data.PullUp(header.TCPMinimumSize) + if !ok { + return false + } + hdr := header.TCP(h) // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: @@ -156,12 +160,16 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(h.DataOffset()) - if offset < header.TCPMinimumSize || offset > len(h) { + offset := int(hdr.DataOffset()) + if offset < header.TCPMinimumSize { + return false + } + hdrWithOpts, ok := s.data.PullUp(offset) + if !ok { return false } - s.options = []byte(h[header.TCPMinimumSize:offset]) + s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -173,18 +181,19 @@ func (s *segment) parse() bool { s.data.TrimFront(offset) } if verifyChecksum { - s.csum = h.Checksum() + hdr = header.TCP(hdrWithOpts) + s.csum = hdr.Checksum() xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = h.CalculateChecksum(xsum) + xsum = hdr.CalculateChecksum(xsum) s.data.TrimFront(offset) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(h.SequenceNumber()) - s.ackNumber = seqnum.Value(h.AckNumber()) - s.flags = h.Flags() - s.window = seqnum.Size(h.WindowSize()) + s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) + s.ackNumber = seqnum.Value(hdr.AckNumber()) + s.flags = hdr.Flags() + s.window = seqnum.Size(hdr.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index ab1014c7f..286c66cf5 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3548,7 +3548,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 c.SendSegment(vv) @@ -3575,7 +3575,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] // Overwrite a byte in the payload which should cause checksum // verification to fail. tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index edb54f0be..756ab913a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, - Port: hdr.SourcePort(), + Port: header.UDP(hdr).SourcePort(), }, } packet.data = pkt.Data diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 6e31a9bac..52af6de22 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -68,8 +68,13 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + h, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { + // Malformed packet. + r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + return true + } + if int(header.UDP(h).Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true -- cgit v1.2.3 From 55f0c3316af8ea2a1fcc16511efc580f307623f6 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Mon, 27 Apr 2020 12:25:10 -0700 Subject: Automated rollback of changelist 308163542 PiperOrigin-RevId: 308674219 --- pkg/sentry/socket/netfilter/tcp_matcher.go | 5 +- pkg/sentry/socket/netfilter/udp_matcher.go | 5 +- pkg/tcpip/buffer/view.go | 55 ++++---------- pkg/tcpip/buffer/view_test.go | 113 ----------------------------- pkg/tcpip/link/loopback/loopback.go | 10 +-- pkg/tcpip/link/rawfile/BUILD | 9 +-- pkg/tcpip/link/rawfile/rawfile_test.go | 46 ------------ pkg/tcpip/link/rawfile/rawfile_unsafe.go | 6 +- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 65 ++++------------- pkg/tcpip/network/arp/arp.go | 5 +- pkg/tcpip/network/ipv4/icmp.go | 20 ++--- pkg/tcpip/network/ipv4/ipv4.go | 12 +-- pkg/tcpip/network/ipv6/icmp.go | 74 +++++++------------ pkg/tcpip/network/ipv6/icmp_test.go | 3 +- pkg/tcpip/network/ipv6/ipv6.go | 6 +- pkg/tcpip/stack/forwarder_test.go | 13 ++-- pkg/tcpip/stack/iptables.go | 22 +----- pkg/tcpip/stack/iptables_targets.go | 23 ++---- pkg/tcpip/stack/nic.go | 34 ++++++--- pkg/tcpip/stack/packet_buffer.go | 8 +- pkg/tcpip/stack/stack_test.go | 10 +-- pkg/tcpip/stack/transport_test.go | 5 +- pkg/tcpip/transport/icmp/endpoint.go | 8 +- pkg/tcpip/transport/tcp/segment.go | 29 +++----- pkg/tcpip/transport/tcp/tcp_test.go | 4 +- pkg/tcpip/transport/udp/endpoint.go | 6 +- pkg/tcpip/transport/udp/protocol.go | 9 +-- 28 files changed, 149 insertions(+), 458 deletions(-) delete mode 100644 pkg/tcpip/link/rawfile/rawfile_test.go (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 55c0f04f3..ff1cfd8f6 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -121,13 +121,12 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa tcpHeader = header.TCP(pkt.TransportHeader) } else { // The TCP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.TCPMinimumSize { // There's no valid TCP header here, so we hotdrop the // packet. return false, true } - tcpHeader = header.TCP(hdr) + tcpHeader = header.TCP(pkt.Data.First()) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 04d03d494..3359418c1 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -120,13 +120,12 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa udpHeader = header.UDP(pkt.TransportHeader) } else { // The UDP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.UDPMinimumSize { // There's no valid UDP header here, so we hotdrop the // packet. return false, true } - udpHeader = header.UDP(hdr) + udpHeader = header.UDP(pkt.Data.First()) } // Check whether the source and destination ports are within the diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index f01217c91..8ec5d5d5c 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -77,8 +77,7 @@ func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } -// TrimFront removes the first "count" bytes of the vectorised view. It panics -// if count > vv.Size(). +// TrimFront removes the first "count" bytes of the vectorised view. func (vv *VectorisedView) TrimFront(count int) { for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { @@ -87,7 +86,7 @@ func (vv *VectorisedView) TrimFront(count int) { return } count -= len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } } @@ -105,7 +104,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) { count -= len(vv.views[0]) copy(v[copied:], vv.views[0]) copied += len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } if copied == 0 { return 0, io.EOF @@ -127,7 +126,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int count -= len(vv.views[0]) dstVV.AppendView(vv.views[0]) copied += len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } return copied } @@ -163,37 +162,22 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } -// PullUp returns the first "count" bytes of the vectorised view. If those -// bytes aren't already contiguous inside the vectorised view, PullUp will -// reallocate as needed to make them contiguous. PullUp fails and returns false -// when count > vv.Size(). -func (vv *VectorisedView) PullUp(count int) (View, bool) { +// First returns the first view of the vectorised view. +func (vv *VectorisedView) First() View { if len(vv.views) == 0 { - return nil, count == 0 - } - if count <= len(vv.views[0]) { - return vv.views[0][:count], true - } - if count > vv.size { - return nil, false + return nil } + return vv.views[0] +} - newFirst := NewView(count) - i := 0 - for offset := 0; offset < count; i++ { - copy(newFirst[offset:], vv.views[i]) - if count-offset < len(vv.views[i]) { - vv.views[i].TrimFront(count - offset) - break - } - offset += len(vv.views[i]) - vv.views[i] = nil +// RemoveFirst removes the first view of the vectorised view. +func (vv *VectorisedView) RemoveFirst() { + if len(vv.views) == 0 { + return } - // We're guaranteed that i > 0, since count is too large for the first - // view. - vv.views[i-1] = newFirst - vv.views = vv.views[i-1:] - return newFirst, true + vv.size -= len(vv.views[0]) + vv.views[0] = nil + vv.views = vv.views[1:] } // Size returns the size in bytes of the entire content stored in the vectorised view. @@ -241,10 +225,3 @@ func (vv *VectorisedView) Readers() []bytes.Reader { } return readers } - -// removeFirst panics when len(vv.views) < 1. -func (vv *VectorisedView) removeFirst() { - vv.size -= len(vv.views[0]) - vv.views[0] = nil - vv.views = vv.views[1:] -} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index c56795c7b..106e1994c 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -16,7 +16,6 @@ package buffer import ( - "bytes" "reflect" "testing" ) @@ -371,115 +370,3 @@ func TestVVRead(t *testing.T) { }) } } - -var pullUpTestCases = []struct { - comment string - in VectorisedView - count int - want []byte - result VectorisedView - ok bool -}{ - { - comment: "simple case", - in: vv(2, "12"), - count: 1, - want: []byte("1"), - result: vv(2, "12"), - ok: true, - }, - { - comment: "entire View", - in: vv(2, "1", "2"), - count: 1, - want: []byte("1"), - result: vv(2, "1", "2"), - ok: true, - }, - { - comment: "spanning across two Views", - in: vv(3, "1", "23"), - count: 2, - want: []byte("12"), - result: vv(3, "12", "3"), - ok: true, - }, - { - comment: "spanning across all Views", - in: vv(5, "1", "23", "45"), - count: 5, - want: []byte("12345"), - result: vv(5, "12345"), - ok: true, - }, - { - comment: "count = 0", - in: vv(1, "1"), - count: 0, - want: []byte{}, - result: vv(1, "1"), - ok: true, - }, - { - comment: "count = size", - in: vv(1, "1"), - count: 1, - want: []byte("1"), - result: vv(1, "1"), - ok: true, - }, - { - comment: "count too large", - in: vv(3, "1", "23"), - count: 4, - want: nil, - result: vv(3, "1", "23"), - ok: false, - }, - { - comment: "empty vv", - in: vv(0, ""), - count: 1, - want: nil, - result: vv(0, ""), - ok: false, - }, - { - comment: "empty vv, count = 0", - in: vv(0, ""), - count: 0, - want: nil, - result: vv(0, ""), - ok: true, - }, - { - comment: "empty views", - in: vv(3, "", "1", "", "23"), - count: 2, - want: []byte("12"), - result: vv(3, "12", "3"), - ok: true, - }, -} - -func TestPullUp(t *testing.T) { - for _, c := range pullUpTestCases { - got, ok := c.in.PullUp(c.count) - - // Is the return value right? - if ok != c.ok { - t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", - c.comment, c.count, c.in, ok, c.ok) - } - if bytes.Compare(got, View(c.want)) != 0 { - t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", - c.comment, c.count, c.in, got, c.want) - } - - // Is the underlying structure right? - if !reflect.DeepEqual(c.in, c.result) { - t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", - c.comment, c.count, c.in, c.result) - } - } -} diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 073c84ef9..1e2255bfa 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // There should be an ethernet header at the beginning of vv. - hdr, ok := vv.PullUp(header.EthernetMinimumSize) - if !ok { - // Reject the packet if it's shorter than an ethernet header. + // Reject the packet if it's shorter than an ethernet header. + if vv.Size() < header.EthernetMinimumSize { return tcpip.ErrBadAddress } - linkHeader := header.Ethernet(hdr) + + // There should be an ethernet header at the beginning of vv. + linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 9cc08d0e2..14b527bc2 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -18,10 +18,3 @@ go_library( "@org_golang_x_sys//unix:go_default_library", ], ) - -go_test( - name = "rawfile_test", - size = "small", - srcs = ["rawfile_test.go"], - library = ":rawfile", -) diff --git a/pkg/tcpip/link/rawfile/rawfile_test.go b/pkg/tcpip/link/rawfile/rawfile_test.go deleted file mode 100644 index 8f14ba761..000000000 --- a/pkg/tcpip/link/rawfile/rawfile_test.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -package rawfile - -import ( - "syscall" - "testing" -) - -func TestNonBlockingWrite3ZeroLength(t *testing.T) { - fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0) - if err != nil { - t.Fatalf("failed to open /dev/null: %v", err) - } - defer syscall.Close(fd) - - if err := NonBlockingWrite3(fd, []byte{}, []byte{0}, nil); err != nil { - t.Fatalf("failed to write: %v", err) - } -} - -func TestNonBlockingWrite3Nil(t *testing.T) { - fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0) - if err != nil { - t.Fatalf("failed to open /dev/null: %v", err) - } - defer syscall.Close(fd) - - if err := NonBlockingWrite3(fd, nil, []byte{0}, nil); err != nil { - t.Fatalf("failed to write: %v", err) - } -} diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 92efd0bf8..44e25d475 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -76,13 +76,9 @@ func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error { // We have two buffers. Build the iovec that represents them and issue // a writev syscall. - var base *byte - if len(b1) > 0 { - base = &b1[0] - } iovec := [3]syscall.Iovec{ { - Base: base, + Base: &b1[0], Len: uint64(len(b1)), }, { diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 33f640b85..27ea3f531 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.ToView()) + rcvd := []byte(c.packets[0].vv.First()) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 0799c8f4d..be2537a82 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -171,7 +171,11 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - logPacket(prefix, protocol, pkt, gso) + first := pkt.Header.View() + if len(first) == 0 { + first = pkt.Data.First() + } + logPacket(prefix, protocol, first, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Header.UsedLength() + pkt.Data.Size() @@ -234,7 +238,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() } -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -243,49 +247,28 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P size := uint16(0) var fragmentOffset uint16 var moreFragments bool - - // Create a clone of pkt, including any headers if present. Avoid allocating - // backing memory for the clone. - views := [8]buffer.View{} - vv := buffer.NewVectorisedView(0, views[:0]) - vv.AppendView(pkt.Header.View()) - vv.Append(pkt.Data) - switch protocol { case header.IPv4ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - ipv4 := header.IPv4(hdr) + ipv4 := header.IPv4(b) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - vv.TrimFront(int(ipv4.HeaderLength())) + b = b[ipv4.HeaderLength():] id = int(ipv4.ID()) case header.IPv6ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv6MinimumSize) - if !ok { - return - } - ipv6 := header.IPv6(hdr) + ipv6 := header.IPv6(b) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() transProto = ipv6.NextHeader() size = ipv6.PayloadLength() - vv.TrimFront(header.IPv6MinimumSize) + b = b[header.IPv6MinimumSize:] case header.ARPProtocolNumber: - hdr, ok := vv.PullUp(header.ARPSize) - if !ok { - return - } - vv.TrimFront(header.ARPSize) - arp := header.ARP(hdr) + arp := header.ARP(b) log.Infof( "%s arp %v (%v) -> %v (%v) valid:%v", prefix, @@ -301,7 +284,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P // We aren't guaranteed to have a transport header - it's possible for // writes via raw endpoints to contain only network headers. - if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize { + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) return } @@ -314,11 +297,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) - if !ok { - break - } - icmp := header.ICMPv4(hdr) + icmp := header.ICMPv4(b) icmpType := "unknown" if fragmentOffset == 0 { switch icmp.Type() { @@ -351,11 +330,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.ICMPv6ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) - if !ok { - break - } - icmp := header.ICMPv6(hdr) + icmp := header.ICMPv6(b) icmpType := "unknown" switch icmp.Type() { case header.ICMPv6DstUnreachable: @@ -386,11 +361,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.UDPProtocolNumber: transName = "udp" - hdr, ok := vv.PullUp(header.UDPMinimumSize) - if !ok { - break - } - udp := header.UDP(hdr) + udp := header.UDP(b) if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -400,11 +371,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.TCPProtocolNumber: transName = "tcp" - hdr, ok := vv.PullUp(header.TCPMinimumSize) - if !ok { - break - } - tcp := header.TCP(hdr) + tcp := header.TCP(b) if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index cf73a939e..7acbfa0a8 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -93,10 +93,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf } func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v, ok := pkt.Data.PullUp(header.ARPSize) - if !ok { - return - } + v := pkt.Data.First() h := header.ARP(v) if !h.IsValid() { return diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 4cbefe5ab..c4bf1ba5c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -25,11 +25,7 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - hdr := header.IPv4(h) + h := header.IPv4(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -38,12 +34,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { return } - hlen := int(hdr.HeaderLength()) - if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { + hlen := int(h.HeaderLength()) + if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -52,15 +48,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) - p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + p := h.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok { + v := pkt.Data.First() + if len(v) < header.ICMPv4MinimumSize { received.Invalid.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 17202cc7a..104aafbed 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -328,11 +328,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return tcpip.ErrInvalidOptionValue - } - ip := header.IPv4(h) + ip := header.IPv4(pkt.Data.First()) if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } @@ -382,11 +378,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } + headerView := pkt.Data.First() h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index bdf3a0d25..b68983d10 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -28,11 +28,7 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { - return - } - hdr := header.IPv6(h) + h := header.IPv6(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -40,21 +36,17 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { return } // Skip the IP header, then handle the fragmentation header if there // is one. pkt.Data.TrimFront(header.IPv6MinimumSize) - p := hdr.TransportProtocol() + p := h.TransportProtocol() if p == header.IPv6FragmentHeader { - f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) - if !ok { - return - } - fragHdr := header.IPv6Fragment(f) - if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { + f := header.IPv6Fragment(pkt.Data.First()) + if !f.IsValid() || f.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. return @@ -63,19 +55,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip fragmentation header and find out the actual protocol // number. pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) - p = fragHdr.TransportProtocol() + p = f.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) - if !ok { + v := pkt.Data.First() + if len(v) < header.ICMPv6MinimumSize { received.Invalid.Increment() return } @@ -84,9 +76,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Validate ICMPv6 checksum before processing the packet. // + // Only the first view in vv is accounted for by h. To account for the + // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) - payload.TrimFront(len(h)) + payload.RemoveFirst() if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return @@ -107,40 +101,34 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch h.Type() { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() - hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) - if !ok { + if len(v) < header.ICMPv6PacketTooBigMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := header.ICMPv6(hdr).MTU() + mtu := h.MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() - hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) - if !ok { + if len(v) < header.ICMPv6DstUnreachableMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch header.ICMPv6(hdr).Code() { + switch h.Code() { case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { received.Invalid.Increment() return } - // The remainder of payload must be only the neighbor solicitation, so - // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, - // NDP messages cannot be fragmented. Also note that in the common case NDP - // datagrams are very small and ToView() will not incur allocations. - ns := header.NDPNeighborSolicit(payload.ToView()) + ns := header.NDPNeighborSolicit(h.NDPPayload()) it, err := ns.Options().Iter(true) if err != nil { // If we have a malformed NDP NS option, drop the packet. @@ -298,16 +286,12 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() { received.Invalid.Increment() return } - // The remainder of payload must be only the neighbor advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. - na := header.NDPNeighborAdvert(payload.ToView()) + na := header.NDPNeighborAdvert(h.NDPPayload()) it, err := na.Options().Iter(true) if err != nil { // If we have a malformed NDP NA option, drop the packet. @@ -379,15 +363,14 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) - if !ok { + if len(v) < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(packet, icmpHdr) + copy(packet, h) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ @@ -401,7 +384,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoReply: received.EchoReply.Increment() - if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { + if len(v) < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } @@ -423,9 +406,8 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterAdvert: received.RouterAdvert.Increment() - // Is the NDP payload of sufficient size to hold a Router - // Advertisement? - if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { + p := h.NDPPayload() + if len(p) < header.NDPRAMinimumSize || !isNDPValid() { received.Invalid.Increment() return } @@ -443,11 +425,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - // The remainder of payload must be only the router advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. - ra := header.NDPRouterAdvert(payload.ToView()) + ra := header.NDPRouterAdvert(p) opts := ra.Options() // Are options valid as per the wire format? diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d412ff688..bd099a7f8 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -166,8 +166,7 @@ func TestICMPCounts(t *testing.T) { }, { typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - }, + size: header.ICMPv6NeighborSolicitMinimumSize}, { typ: header.ICMPv6NeighborAdvert, size: header.ICMPv6NeighborAdvertMinimumSize, diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 486725131..331b0817b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,11 +171,7 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } + headerView := pkt.Data.First() h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index c7c663498..e9c652042 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -70,10 +70,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. - b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) - if !ok { - return - } + b := pkt.Data.First() pkt.Data.TrimFront(fwdTestNetHeaderLen) // Dispatch the packet to the transport protocol. @@ -476,7 +473,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -520,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -567,7 +564,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -622,7 +619,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6b91159d4..6c0a4b24d 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -212,11 +212,6 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. -// // NOTE: unlike the Check API the returned map contains packets that should be // dropped. func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { @@ -231,9 +226,7 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa return drop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Precondition: pkt.NetworkHeader is set. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -278,21 +271,14 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx return chainDrop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Precondition: pk.NetworkHeader is set. func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data. + // pkt.Data.First(). if pkt.NetworkHeader == nil { - var ok bool - pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // Precondition has been violated. - panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) - } + pkt.NetworkHeader = pkt.Data.First() } // Check whether the packet matches the IP header filter. diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 8be61f4b1..7b4543caf 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -96,12 +96,9 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { newPkt := pkt.Clone() // Set network header. - headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return RuleDrop, 0 - } + headerView := newPkt.Data.First() netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView + newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] hlen := int(netHeader.HeaderLength()) tlen := int(netHeader.TotalLength()) @@ -120,14 +117,10 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { udpHeader = header.UDP(newPkt.TransportHeader) } else { - if pkt.Data.Size() < header.UDPMinimumSize { - return RuleDrop, 0 - } - hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.UDPMinimumSize { return RuleDrop, 0 } - udpHeader = header.UDP(hdr) + udpHeader = header.UDP(newPkt.Data.First()) } udpHeader.SetDestinationPort(rt.MinPort) case header.TCPProtocolNumber: @@ -135,14 +128,10 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { tcpHeader = header.TCP(newPkt.TransportHeader) } else { - if pkt.Data.Size() < header.TCPMinimumSize { + if len(pkt.Data.First()) < header.TCPMinimumSize { return RuleDrop, 0 } - hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) - if !ok { - return RuleDrop, 0 - } - tcpHeader = header.TCP(hdr) + tcpHeader = header.TCP(newPkt.TransportHeader) } // TODO(gvisor.dev/issue/170): Need to recompute checksum // and implement nat connection tracking to support TCP. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 0c2b1f36a..016dbe15e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1203,12 +1203,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) - if !ok { + if len(pkt.Data.First()) < netProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - src, dst := netProto.ParseAddresses(netHeader) + + src, dst := netProto.ParseAddresses(pkt.Data.First()) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1289,8 +1289,22 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { - pkt.Header = buffer.NewPrependable(linkHeaderLen) + + firstData := pkt.Data.First() + pkt.Data.RemoveFirst() + + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { + pkt.Header = buffer.NewPrependableFromView(firstData) + } else { + firstDataLen := len(firstData) + + // pkt.Header should have enough capacity to hold n.linkEP's headers. + pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) + + // TODO(b/151227689): avoid copying the packet when forwarding + if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) + } } if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { @@ -1318,13 +1332,12 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) - if !ok { + if len(pkt.Data.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1362,12 +1375,11 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - transHeader, ok := pkt.Data.PullUp(8) - if !ok { + if len(pkt.Data.First()) < 8 { return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 7d36f8e84..dc125f25e 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -37,13 +37,7 @@ type PacketBuffer struct { Data buffer.VectorisedView // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. Note that forwarded - // packets don't populate Headers on their way out -- their headers and - // payload are never parsed out and remain in Data. - // - // TODO(gvisor.dev/issue/170): Forwarded packets don't currently - // populate Header, but should. This will be doable once early parsing - // (https://github.com/google/gvisor/pull/1995) is supported. + // down the stack, each layer adds to Header. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index d45d2cc1f..c7634ceb1 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -95,18 +95,16 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { - return - } + b := pkt.Data.First() pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { + nb := pkt.Data.First() + if len(nb) < fakeNetHeaderLen { return } + pkt.Data.TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a611e44ab..3084e6593 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -642,11 +642,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - hdrs := p.Pkt.Data.ToView() - if dst := hdrs[0]; dst != 3 { + if dst := p.Pkt.Header.View()[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := hdrs[1]; src != 1 { + if src := p.Pkt.Header.View()[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index b1d820372..feef8dca0 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { + h := header.ICMPv4(pkt.Data.First()) + if h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) - if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { + h := header.ICMPv6(pkt.Data.First()) + if h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 7712ce652..40461fd31 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -144,11 +144,7 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h, ok := s.data.PullUp(header.TCPMinimumSize) - if !ok { - return false - } - hdr := header.TCP(h) + h := header.TCP(s.data.First()) // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: @@ -160,16 +156,12 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(hdr.DataOffset()) - if offset < header.TCPMinimumSize { - return false - } - hdrWithOpts, ok := s.data.PullUp(offset) - if !ok { + offset := int(h.DataOffset()) + if offset < header.TCPMinimumSize || offset > len(h) { return false } - s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) + s.options = []byte(h[header.TCPMinimumSize:offset]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -181,19 +173,18 @@ func (s *segment) parse() bool { s.data.TrimFront(offset) } if verifyChecksum { - hdr = header.TCP(hdrWithOpts) - s.csum = hdr.Checksum() + s.csum = h.Checksum() xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = hdr.CalculateChecksum(xsum) + xsum = h.CalculateChecksum(xsum) s.data.TrimFront(offset) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) - s.ackNumber = seqnum.Value(hdr.AckNumber()) - s.flags = hdr.Flags() - s.window = seqnum.Size(hdr.WindowSize()) + s.sequenceNumber = seqnum.Value(h.SequenceNumber()) + s.ackNumber = seqnum.Value(h.AckNumber()) + s.flags = h.Flags() + s.window = seqnum.Size(h.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 286c66cf5..ab1014c7f 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3548,7 +3548,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] + tcpbuf := vv.First()[header.IPv4MinimumSize:] tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 c.SendSegment(vv) @@ -3575,7 +3575,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] + tcpbuf := vv.First()[header.IPv4MinimumSize:] // Overwrite a byte in the payload which should cause checksum // verification to fail. tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 756ab913a..edb54f0be 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, - Port: header.UDP(hdr).SourcePort(), + Port: hdr.SourcePort(), }, } packet.data = pkt.Data diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 52af6de22..6e31a9bac 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -68,13 +68,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. - h, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { - // Malformed packet. - r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() - return true - } - if int(header.UDP(h).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true -- cgit v1.2.3 From 5e1e61fbcbe8fa3cc8b104fadb8cdef3ad29c31f Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Fri, 1 May 2020 16:08:26 -0700 Subject: Automated rollback of changelist 308674219 PiperOrigin-RevId: 309491861 --- pkg/sentry/socket/netfilter/tcp_matcher.go | 5 +- pkg/sentry/socket/netfilter/udp_matcher.go | 5 +- pkg/tcpip/buffer/view.go | 55 ++++++++++---- pkg/tcpip/buffer/view_test.go | 113 +++++++++++++++++++++++++++++ pkg/tcpip/link/fdbased/endpoint.go | 3 + pkg/tcpip/link/loopback/loopback.go | 10 +-- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 65 +++++++++++++---- pkg/tcpip/network/arp/arp.go | 5 +- pkg/tcpip/network/ipv4/icmp.go | 20 +++-- pkg/tcpip/network/ipv4/ipv4.go | 12 ++- pkg/tcpip/network/ipv6/icmp.go | 74 ++++++++++++------- pkg/tcpip/network/ipv6/icmp_test.go | 3 +- pkg/tcpip/network/ipv6/ipv6.go | 6 +- pkg/tcpip/stack/forwarder_test.go | 13 ++-- pkg/tcpip/stack/iptables.go | 22 +++++- pkg/tcpip/stack/iptables_targets.go | 23 ++++-- pkg/tcpip/stack/nic.go | 34 +++------ pkg/tcpip/stack/packet_buffer.go | 8 +- pkg/tcpip/stack/stack_test.go | 10 ++- pkg/tcpip/stack/transport_test.go | 5 +- pkg/tcpip/transport/icmp/endpoint.go | 8 +- pkg/tcpip/transport/tcp/segment.go | 29 +++++--- pkg/tcpip/transport/tcp/tcp_test.go | 4 +- pkg/tcpip/transport/udp/endpoint.go | 6 +- pkg/tcpip/transport/udp/protocol.go | 9 ++- 26 files changed, 402 insertions(+), 147 deletions(-) (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index ff1cfd8f6..55c0f04f3 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -121,12 +121,13 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa tcpHeader = header.TCP(pkt.TransportHeader) } else { // The TCP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.TCPMinimumSize { + hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + if !ok { // There's no valid TCP header here, so we hotdrop the // packet. return false, true } - tcpHeader = header.TCP(pkt.Data.First()) + tcpHeader = header.TCP(hdr) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 3359418c1..04d03d494 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -120,12 +120,13 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa udpHeader = header.UDP(pkt.TransportHeader) } else { // The UDP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.UDPMinimumSize { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { // There's no valid UDP header here, so we hotdrop the // packet. return false, true } - udpHeader = header.UDP(pkt.Data.First()) + udpHeader = header.UDP(hdr) } // Check whether the source and destination ports are within the diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 8ec5d5d5c..f01217c91 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -77,7 +77,8 @@ func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } -// TrimFront removes the first "count" bytes of the vectorised view. +// TrimFront removes the first "count" bytes of the vectorised view. It panics +// if count > vv.Size(). func (vv *VectorisedView) TrimFront(count int) { for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { @@ -86,7 +87,7 @@ func (vv *VectorisedView) TrimFront(count int) { return } count -= len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } } @@ -104,7 +105,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) { count -= len(vv.views[0]) copy(v[copied:], vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } if copied == 0 { return 0, io.EOF @@ -126,7 +127,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int count -= len(vv.views[0]) dstVV.AppendView(vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } return copied } @@ -162,22 +163,37 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } -// First returns the first view of the vectorised view. -func (vv *VectorisedView) First() View { +// PullUp returns the first "count" bytes of the vectorised view. If those +// bytes aren't already contiguous inside the vectorised view, PullUp will +// reallocate as needed to make them contiguous. PullUp fails and returns false +// when count > vv.Size(). +func (vv *VectorisedView) PullUp(count int) (View, bool) { if len(vv.views) == 0 { - return nil + return nil, count == 0 + } + if count <= len(vv.views[0]) { + return vv.views[0][:count], true + } + if count > vv.size { + return nil, false } - return vv.views[0] -} -// RemoveFirst removes the first view of the vectorised view. -func (vv *VectorisedView) RemoveFirst() { - if len(vv.views) == 0 { - return + newFirst := NewView(count) + i := 0 + for offset := 0; offset < count; i++ { + copy(newFirst[offset:], vv.views[i]) + if count-offset < len(vv.views[i]) { + vv.views[i].TrimFront(count - offset) + break + } + offset += len(vv.views[i]) + vv.views[i] = nil } - vv.size -= len(vv.views[0]) - vv.views[0] = nil - vv.views = vv.views[1:] + // We're guaranteed that i > 0, since count is too large for the first + // view. + vv.views[i-1] = newFirst + vv.views = vv.views[i-1:] + return newFirst, true } // Size returns the size in bytes of the entire content stored in the vectorised view. @@ -225,3 +241,10 @@ func (vv *VectorisedView) Readers() []bytes.Reader { } return readers } + +// removeFirst panics when len(vv.views) < 1. +func (vv *VectorisedView) removeFirst() { + vv.size -= len(vv.views[0]) + vv.views[0] = nil + vv.views = vv.views[1:] +} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index 106e1994c..c56795c7b 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -16,6 +16,7 @@ package buffer import ( + "bytes" "reflect" "testing" ) @@ -370,3 +371,115 @@ func TestVVRead(t *testing.T) { }) } } + +var pullUpTestCases = []struct { + comment string + in VectorisedView + count int + want []byte + result VectorisedView + ok bool +}{ + { + comment: "simple case", + in: vv(2, "12"), + count: 1, + want: []byte("1"), + result: vv(2, "12"), + ok: true, + }, + { + comment: "entire View", + in: vv(2, "1", "2"), + count: 1, + want: []byte("1"), + result: vv(2, "1", "2"), + ok: true, + }, + { + comment: "spanning across two Views", + in: vv(3, "1", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, + { + comment: "spanning across all Views", + in: vv(5, "1", "23", "45"), + count: 5, + want: []byte("12345"), + result: vv(5, "12345"), + ok: true, + }, + { + comment: "count = 0", + in: vv(1, "1"), + count: 0, + want: []byte{}, + result: vv(1, "1"), + ok: true, + }, + { + comment: "count = size", + in: vv(1, "1"), + count: 1, + want: []byte("1"), + result: vv(1, "1"), + ok: true, + }, + { + comment: "count too large", + in: vv(3, "1", "23"), + count: 4, + want: nil, + result: vv(3, "1", "23"), + ok: false, + }, + { + comment: "empty vv", + in: vv(0, ""), + count: 1, + want: nil, + result: vv(0, ""), + ok: false, + }, + { + comment: "empty vv, count = 0", + in: vv(0, ""), + count: 0, + want: nil, + result: vv(0, ""), + ok: true, + }, + { + comment: "empty views", + in: vv(3, "", "1", "", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, +} + +func TestPullUp(t *testing.T) { + for _, c := range pullUpTestCases { + got, ok := c.in.PullUp(c.count) + + // Is the return value right? + if ok != c.ok { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", + c.comment, c.count, c.in, ok, c.ok) + } + if bytes.Compare(got, View(c.want)) != 0 { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", + c.comment, c.count, c.in, got, c.want) + } + + // Is the underlying structure right? + if !reflect.DeepEqual(c.in, c.result) { + t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", + c.comment, c.count, c.in, c.result) + } + } +} diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 53a9712c6..affa1bbdf 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -436,6 +436,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne if pkt.Data.Size() == 0 { return rawfile.NonBlockingWrite(fd, pkt.Header.View()) } + if pkt.Header.UsedLength() == 0 { + return rawfile.NonBlockingWrite(fd, pkt.Data.ToView()) + } return rawfile.NonBlockingWrite3(fd, pkt.Header.View(), pkt.Data.ToView(), nil) } diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 1e2255bfa..073c84ef9 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // Reject the packet if it's shorter than an ethernet header. - if vv.Size() < header.EthernetMinimumSize { + // There should be an ethernet header at the beginning of vv. + hdr, ok := vv.PullUp(header.EthernetMinimumSize) + if !ok { + // Reject the packet if it's shorter than an ethernet header. return tcpip.ErrBadAddress } - - // There should be an ethernet header at the beginning of vv. - linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) + linkHeader := header.Ethernet(hdr) vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 27ea3f531..33f640b85 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.First()) + rcvd := []byte(c.packets[0].vv.ToView()) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index be2537a82..0799c8f4d 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -171,11 +171,7 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - first := pkt.Header.View() - if len(first) == 0 { - first = pkt.Data.First() - } - logPacket(prefix, protocol, first, gso) + logPacket(prefix, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Header.UsedLength() + pkt.Data.Size() @@ -238,7 +234,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() } -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -247,28 +243,49 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie size := uint16(0) var fragmentOffset uint16 var moreFragments bool + + // Create a clone of pkt, including any headers if present. Avoid allocating + // backing memory for the clone. + views := [8]buffer.View{} + vv := buffer.NewVectorisedView(0, views[:0]) + vv.AppendView(pkt.Header.View()) + vv.Append(pkt.Data) + switch protocol { case header.IPv4ProtocolNumber: - ipv4 := header.IPv4(b) + hdr, ok := vv.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + ipv4 := header.IPv4(hdr) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - b = b[ipv4.HeaderLength():] + vv.TrimFront(int(ipv4.HeaderLength())) id = int(ipv4.ID()) case header.IPv6ProtocolNumber: - ipv6 := header.IPv6(b) + hdr, ok := vv.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + ipv6 := header.IPv6(hdr) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() transProto = ipv6.NextHeader() size = ipv6.PayloadLength() - b = b[header.IPv6MinimumSize:] + vv.TrimFront(header.IPv6MinimumSize) case header.ARPProtocolNumber: - arp := header.ARP(b) + hdr, ok := vv.PullUp(header.ARPSize) + if !ok { + return + } + vv.TrimFront(header.ARPSize) + arp := header.ARP(hdr) log.Infof( "%s arp %v (%v) -> %v (%v) valid:%v", prefix, @@ -284,7 +301,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie // We aren't guaranteed to have a transport header - it's possible for // writes via raw endpoints to contain only network headers. - if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize { log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) return } @@ -297,7 +314,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - icmp := header.ICMPv4(b) + hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) + if !ok { + break + } + icmp := header.ICMPv4(hdr) icmpType := "unknown" if fragmentOffset == 0 { switch icmp.Type() { @@ -330,7 +351,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.ICMPv6ProtocolNumber: transName = "icmp" - icmp := header.ICMPv6(b) + hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) + if !ok { + break + } + icmp := header.ICMPv6(hdr) icmpType := "unknown" switch icmp.Type() { case header.ICMPv6DstUnreachable: @@ -361,7 +386,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.UDPProtocolNumber: transName = "udp" - udp := header.UDP(b) + hdr, ok := vv.PullUp(header.UDPMinimumSize) + if !ok { + break + } + udp := header.UDP(hdr) if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -371,7 +400,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.TCPProtocolNumber: transName = "tcp" - tcp := header.TCP(b) + hdr, ok := vv.PullUp(header.TCPMinimumSize) + if !ok { + break + } + tcp := header.TCP(hdr) if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 9f47b4ff2..9d0797af7 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -99,7 +99,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf } func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v := pkt.Data.First() + v, ok := pkt.Data.PullUp(header.ARPSize) + if !ok { + return + } h := header.ARP(v) if !h.IsValid() { return diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index c4bf1ba5c..4cbefe5ab 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -25,7 +25,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + hdr := header.IPv4(h) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -34,12 +38,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } - hlen := int(h.HeaderLength()) - if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { + hlen := int(hdr.HeaderLength()) + if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -48,15 +52,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) - p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + p := hdr.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv4MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok { received.Invalid.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index a9dec0c0e..1d61fddad 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -333,7 +333,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - ip := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return tcpip.ErrInvalidOptionValue + } + ip := header.IPv4(h) if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } @@ -383,7 +387,11 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index b68983d10..bdf3a0d25 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -28,7 +28,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv6(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + hdr := header.IPv6(h) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -36,17 +40,21 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } // Skip the IP header, then handle the fragmentation header if there // is one. pkt.Data.TrimFront(header.IPv6MinimumSize) - p := h.TransportProtocol() + p := hdr.TransportProtocol() if p == header.IPv6FragmentHeader { - f := header.IPv6Fragment(pkt.Data.First()) - if !f.IsValid() || f.FragmentOffset() != 0 { + f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) + if !ok { + return + } + fragHdr := header.IPv6Fragment(f) + if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. return @@ -55,19 +63,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip fragmentation header and find out the actual protocol // number. pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) - p = f.TransportProtocol() + p = fragHdr.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv6MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) + if !ok { received.Invalid.Increment() return } @@ -76,11 +84,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Validate ICMPv6 checksum before processing the packet. // - // Only the first view in vv is accounted for by h. To account for the - // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) - payload.RemoveFirst() + payload.TrimFront(len(h)) if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return @@ -101,34 +107,40 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch h.Type() { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() - if len(v) < header.ICMPv6PacketTooBigMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := h.MTU() + mtu := header.ICMPv6(hdr).MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() - if len(v) < header.ICMPv6DstUnreachableMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch h.Code() { + switch header.ICMPv6(hdr).Code() { case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { received.Invalid.Increment() return } - ns := header.NDPNeighborSolicit(h.NDPPayload()) + // The remainder of payload must be only the neighbor solicitation, so + // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, + // NDP messages cannot be fragmented. Also note that in the common case NDP + // datagrams are very small and ToView() will not incur allocations. + ns := header.NDPNeighborSolicit(payload.ToView()) it, err := ns.Options().Iter(true) if err != nil { // If we have a malformed NDP NS option, drop the packet. @@ -286,12 +298,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { received.Invalid.Increment() return } - na := header.NDPNeighborAdvert(h.NDPPayload()) + // The remainder of payload must be only the neighbor advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + na := header.NDPNeighborAdvert(payload.ToView()) it, err := na.Options().Iter(true) if err != nil { // If we have a malformed NDP NA option, drop the packet. @@ -363,14 +379,15 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(packet, h) + copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ @@ -384,7 +401,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoReply: received.EchoReply.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } @@ -406,8 +423,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterAdvert: received.RouterAdvert.Increment() - p := h.NDPPayload() - if len(p) < header.NDPRAMinimumSize || !isNDPValid() { + // Is the NDP payload of sufficient size to hold a Router + // Advertisement? + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { received.Invalid.Increment() return } @@ -425,7 +443,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - ra := header.NDPRouterAdvert(p) + // The remainder of payload must be only the router advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + ra := header.NDPRouterAdvert(payload.ToView()) opts := ra.Options() // Are options valid as per the wire format? diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index bd099a7f8..d412ff688 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -166,7 +166,8 @@ func TestICMPCounts(t *testing.T) { }, { typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize}, + size: header.ICMPv6NeighborSolicitMinimumSize, + }, { typ: header.ICMPv6NeighborAdvert, size: header.ICMPv6NeighborAdvertMinimumSize, diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 82928fb66..daf1fcbc6 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,7 +171,11 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 9bc97b84e..8084d50bc 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -70,7 +70,10 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fwdTestNetHeaderLen) // Dispatch the packet to the transport protocol. @@ -477,7 +480,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -521,7 +524,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -568,7 +571,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -623,7 +626,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6c0a4b24d..6b91159d4 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -212,6 +212,11 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. +// // NOTE: unlike the Check API the returned map contains packets that should be // dropped. func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { @@ -226,7 +231,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa return drop } -// Precondition: pkt.NetworkHeader is set. +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -271,14 +278,21 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx return chainDrop } -// Precondition: pk.NetworkHeader is set. +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data.First(). + // pkt.Data. if pkt.NetworkHeader == nil { - pkt.NetworkHeader = pkt.Data.First() + var ok bool + pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + // Precondition has been violated. + panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) + } } // Check whether the packet matches the IP header filter. diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 7b4543caf..8be61f4b1 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -96,9 +96,12 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { newPkt := pkt.Clone() // Set network header. - headerView := newPkt.Data.First() + headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return RuleDrop, 0 + } netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] + newPkt.NetworkHeader = headerView hlen := int(netHeader.HeaderLength()) tlen := int(netHeader.TotalLength()) @@ -117,10 +120,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { udpHeader = header.UDP(newPkt.TransportHeader) } else { - if len(pkt.Data.First()) < header.UDPMinimumSize { + if pkt.Data.Size() < header.UDPMinimumSize { + return RuleDrop, 0 + } + hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) + if !ok { return RuleDrop, 0 } - udpHeader = header.UDP(newPkt.Data.First()) + udpHeader = header.UDP(hdr) } udpHeader.SetDestinationPort(rt.MinPort) case header.TCPProtocolNumber: @@ -128,10 +135,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { tcpHeader = header.TCP(newPkt.TransportHeader) } else { - if len(pkt.Data.First()) < header.TCPMinimumSize { + if pkt.Data.Size() < header.TCPMinimumSize { return RuleDrop, 0 } - tcpHeader = header.TCP(newPkt.TransportHeader) + hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) + if !ok { + return RuleDrop, 0 + } + tcpHeader = header.TCP(hdr) } // TODO(gvisor.dev/issue/170): Need to recompute checksum // and implement nat connection tracking to support TCP. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 440970a21..7b54919bb 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1212,12 +1212,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - if len(pkt.Data.First()) < netProto.MinimumPacketSize() { + netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - - src, dst := netProto.ParseAddresses(pkt.Data.First()) + src, dst := netProto.ParseAddresses(netHeader) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1298,22 +1298,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - - firstData := pkt.Data.First() - pkt.Data.RemoveFirst() - - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { - pkt.Header = buffer.NewPrependableFromView(firstData) - } else { - firstDataLen := len(firstData) - - // pkt.Header should have enough capacity to hold n.linkEP's headers. - pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) - - // TODO(b/151227689): avoid copying the packet when forwarding - if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) - } + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { + pkt.Header = buffer.NewPrependable(linkHeaderLen) } if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { @@ -1341,12 +1327,13 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - if len(pkt.Data.First()) < transProto.MinimumPacketSize() { + transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1384,11 +1371,12 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - if len(pkt.Data.First()) < 8 { + transHeader, ok := pkt.Data.PullUp(8) + if !ok { return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 06d312207..9ff80ab24 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -37,7 +37,13 @@ type PacketBuffer struct { Data buffer.VectorisedView // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. + // down the stack, each layer adds to Header. Note that forwarded + // packets don't populate Headers on their way out -- their headers and + // payload are never parsed out and remain in Data. + // + // TODO(gvisor.dev/issue/170): Forwarded packets don't currently + // populate Header, but should. This will be doable once early parsing + // (https://github.com/google/gvisor/pull/1995) is supported. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 3f4e08434..1a2cf007c 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -95,16 +95,18 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb := pkt.Data.First() - if len(nb) < fakeNetHeaderLen { + nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { return } - pkt.Data.TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 3084e6593..a611e44ab 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -642,10 +642,11 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.Header.View()[0]; dst != 3 { + hdrs := p.Pkt.Data.ToView() + if dst := hdrs[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.Header.View()[1]; src != 1 { + if src := hdrs[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index feef8dca0..b1d820372 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(pkt.Data.First()) - if h.Type() != header.ICMPv4EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(pkt.Data.First()) - if h.Type() != header.ICMPv6EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) + if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 40461fd31..7712ce652 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -144,7 +144,11 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h := header.TCP(s.data.First()) + h, ok := s.data.PullUp(header.TCPMinimumSize) + if !ok { + return false + } + hdr := header.TCP(h) // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: @@ -156,12 +160,16 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(h.DataOffset()) - if offset < header.TCPMinimumSize || offset > len(h) { + offset := int(hdr.DataOffset()) + if offset < header.TCPMinimumSize { + return false + } + hdrWithOpts, ok := s.data.PullUp(offset) + if !ok { return false } - s.options = []byte(h[header.TCPMinimumSize:offset]) + s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -173,18 +181,19 @@ func (s *segment) parse() bool { s.data.TrimFront(offset) } if verifyChecksum { - s.csum = h.Checksum() + hdr = header.TCP(hdrWithOpts) + s.csum = hdr.Checksum() xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = h.CalculateChecksum(xsum) + xsum = hdr.CalculateChecksum(xsum) s.data.TrimFront(offset) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(h.SequenceNumber()) - s.ackNumber = seqnum.Value(h.AckNumber()) - s.flags = h.Flags() - s.window = seqnum.Size(h.WindowSize()) + s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) + s.ackNumber = seqnum.Value(hdr.AckNumber()) + s.flags = hdr.Flags() + s.window = seqnum.Size(hdr.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 7e574859b..33e2b9a09 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3556,7 +3556,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 c.SendSegment(vv) @@ -3583,7 +3583,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] // Overwrite a byte in the payload which should cause checksum // verification to fail. tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index edb54f0be..756ab913a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, - Port: hdr.SourcePort(), + Port: header.UDP(hdr).SourcePort(), }, } packet.data = pkt.Data diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 6e31a9bac..52af6de22 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -68,8 +68,13 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + h, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { + // Malformed packet. + r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + return true + } + if int(header.UDP(h).Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true -- cgit v1.2.3 From b660f16d18827f0310594c80d9387de11430f15f Mon Sep 17 00:00:00 2001 From: Nayana Bidari Date: Fri, 27 Mar 2020 12:18:45 -0700 Subject: Support for connection tracking of TCP packets. Connection tracking is used to track packets in prerouting and output hooks of iptables. The NAT rules modify the tuples in connections. The connection tracking code modifies the packets by looking at the modified tuples. --- pkg/sentry/socket/netfilter/netfilter.go | 14 +- pkg/sentry/socket/netfilter/targets.go | 3 +- pkg/sentry/socket/netfilter/tcp_matcher.go | 17 +- pkg/sentry/socket/netfilter/udp_matcher.go | 17 +- pkg/tcpip/header/tcp.go | 17 + pkg/tcpip/network/ipv4/ipv4.go | 46 ++- pkg/tcpip/stack/BUILD | 2 + pkg/tcpip/stack/conntrack.go | 480 ++++++++++++++++++++++ pkg/tcpip/stack/iptables.go | 51 ++- pkg/tcpip/stack/iptables_targets.go | 115 +++--- pkg/tcpip/stack/iptables_types.go | 4 +- pkg/tcpip/stack/nic.go | 4 +- pkg/tcpip/stack/packet_buffer.go | 4 + pkg/tcpip/stack/route.go | 13 + pkg/tcpip/stack/stack.go | 19 + pkg/tcpip/transport/tcp/BUILD | 2 +- pkg/tcpip/transport/tcp/rcv.go | 19 +- pkg/tcpip/transport/tcp/rcv_test.go | 6 +- pkg/tcpip/transport/tcpconntrack/BUILD | 1 - pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go | 16 +- test/iptables/filter_output.go | 2 +- test/iptables/iptables_test.go | 66 ++- test/iptables/iptables_util.go | 2 +- test/iptables/nat.go | 103 ++++- 24 files changed, 858 insertions(+), 165 deletions(-) create mode 100644 pkg/tcpip/stack/conntrack.go (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 72d093aa8..40736fb38 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -251,7 +251,7 @@ func marshalTarget(target stack.Target) []byte { case stack.ReturnTarget: return marshalStandardTarget(stack.RuleReturn) case stack.RedirectTarget: - return marshalRedirectTarget() + return marshalRedirectTarget(tg) case JumpTarget: return marshalJumpTarget(tg) default: @@ -288,7 +288,7 @@ func marshalErrorTarget(errorName string) []byte { return binary.Marshal(ret, usermem.ByteOrder, target) } -func marshalRedirectTarget() []byte { +func marshalRedirectTarget(rt stack.RedirectTarget) []byte { // This is a redirect target named redirect target := linux.XTRedirectTarget{ Target: linux.XTEntryTarget{ @@ -298,6 +298,16 @@ func marshalRedirectTarget() []byte { copy(target.Target.Name[:], redirectTargetName) ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) + target.NfRange.RangeSize = 1 + if rt.RangeProtoSpecified { + target.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED + } + // Convert port from little endian to big endian. + port := make([]byte, 2) + binary.LittleEndian.PutUint16(port, rt.MinPort) + target.NfRange.RangeIPV4.MinPort = binary.BigEndian.Uint16(port) + binary.LittleEndian.PutUint16(port, rt.MaxPort) + target.NfRange.RangeIPV4.MaxPort = binary.BigEndian.Uint16(port) return binary.Marshal(ret, usermem.ByteOrder, target) } diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index c948de876..84abe8d29 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -15,6 +15,7 @@ package netfilter import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -29,6 +30,6 @@ type JumpTarget struct { } // Action implements stack.Target.Action. -func (jt JumpTarget) Action(stack.PacketBuffer) (stack.RuleVerdict, int) { +func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrackTable, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 55c0f04f3..57a1e1c12 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -120,14 +120,27 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa if pkt.TransportHeader != nil { tcpHeader = header.TCP(pkt.TransportHeader) } else { + var length int + if hook == stack.Prerouting { + // The network header hasn't been parsed yet. We have to do it here. + hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + // There's no valid TCP header here, so we hotdrop the + // packet. + return false, true + } + h := header.IPv4(hdr) + pkt.NetworkHeader = hdr + length = int(h.HeaderLength()) + } // The TCP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + hdr, ok := pkt.Data.PullUp(length + header.TCPMinimumSize) if !ok { // There's no valid TCP header here, so we hotdrop the // packet. return false, true } - tcpHeader = header.TCP(hdr) + tcpHeader = header.TCP(hdr[length:]) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 04d03d494..cfa9e621d 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -119,14 +119,27 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa if pkt.TransportHeader != nil { udpHeader = header.UDP(pkt.TransportHeader) } else { + var length int + if hook == stack.Prerouting { + // The network header hasn't been parsed yet. We have to do it here. + hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + // There's no valid UDP header here, so we hotdrop the + // packet. + return false, true + } + h := header.IPv4(hdr) + pkt.NetworkHeader = hdr + length = int(h.HeaderLength()) + } // The UDP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + hdr, ok := pkt.Data.PullUp(length + header.UDPMinimumSize) if !ok { // There's no valid UDP header here, so we hotdrop the // packet. return false, true } - udpHeader = header.UDP(hdr) + udpHeader = header.UDP(hdr[length:]) } // Check whether the source and destination ports are within the diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 13480687d..21581257b 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -594,3 +594,20 @@ func AddTCPOptionPadding(options []byte, offset int) int { } return paddingToAdd } + +// Acceptable checks if a segment that starts at segSeq and has length segLen is +// "acceptable" for arriving in a receive window that starts at rcvNxt and ends +// before rcvAcc, according to the table on page 26 and 69 of RFC 793. +func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool { + if rcvNxt == rcvAcc { + return segLen == 0 && segSeq == rcvNxt + } + if segLen == 0 { + // rcvWnd is incremented by 1 because that is Linux's behavior despite the + // RFC. + return segSeq.InRange(rcvNxt, rcvAcc.Add(1)) + } + // Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming + // the payload, so we'll accept any payload that overlaps the receieve window. + return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThan(rcvAcc) +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 1d61fddad..9db42b2a4 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -252,11 +252,31 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Output, pkt); !ok { + if ok := ipt.Check(stack.Output, &pkt, gso, r, ""); !ok { // iptables is telling us to drop the packet. return nil } + if pkt.NatDone { + // If the packet is manipulated as per NAT Ouput rules, handle packet + // based on destination address and do not send the packet to link layer. + netHeader := header.IPv4(pkt.NetworkHeader) + ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) + if err == nil { + src := netHeader.SourceAddress() + dst := netHeader.DestinationAddress() + route := r.ReverseRoute(src, dst) + + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + packet := stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} + ep.HandlePacket(&route, packet) + return nil + } + } + if r.Loop&stack.PacketLoop != 0 { // The inbound path expects the network header to still be in // the PacketBuffer's Data field. @@ -302,8 +322,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - dropped := ipt.CheckPackets(stack.Output, pkts) - if len(dropped) == 0 { + dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r) + if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) @@ -318,6 +338,24 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if _, ok := dropped[pkt]; ok { continue } + if _, ok := natPkts[pkt]; ok { + netHeader := header.IPv4(pkt.NetworkHeader) + ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) + if err == nil { + src := netHeader.SourceAddress() + dst := netHeader.DestinationAddress() + route := r.ReverseRoute(src, dst) + + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + packet := stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} + ep.HandlePacket(&route, packet) + n++ + continue + } + } if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) return n, err @@ -407,7 +445,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Input, pkt); !ok { + if ok := ipt.Check(stack.Input, &pkt, nil, nil, ""); !ok { // iptables is telling us to drop the packet. return } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 5e963a4af..f71073207 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -30,6 +30,7 @@ go_template_instance( go_library( name = "stack", srcs = [ + "conntrack.go", "dhcpv6configurationfromndpra_string.go", "forwarder.go", "icmp_rate_limit.go", @@ -62,6 +63,7 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", + "//pkg/tcpip/transport/tcpconntrack", "//pkg/waiter", "@org_golang_x_time//rate:go_default_library", ], diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go new file mode 100644 index 000000000..7d1ede1f2 --- /dev/null +++ b/pkg/tcpip/stack/conntrack.go @@ -0,0 +1,480 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "encoding/binary" + "sync" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" +) + +// Connection tracking is used to track and manipulate packets for NAT rules. +// The connection is created for a packet if it does not exist. Every connection +// contains two tuples (original and reply). The tuples are manipulated if there +// is a matching NAT rule. The packet is modified by looking at the tuples in the +// Prerouting and Output hooks. + +// Direction of the tuple. +type ctDirection int + +const ( + dirOriginal ctDirection = iota + dirReply +) + +// Status of connection. +// TODO(gvisor.dev/issue/170): Add other states of connection. +type connStatus int + +const ( + connNew connStatus = iota + connEstablished +) + +// Manipulation type for the connection. +type manipType int + +const ( + manipDstPrerouting manipType = iota + manipDstOutput +) + +// connTrackMutable is the manipulatable part of the tuple. +type connTrackMutable struct { + // addr is source address of the tuple. + addr tcpip.Address + + // port is source port of the tuple. + port uint16 + + // protocol is network layer protocol. + protocol tcpip.NetworkProtocolNumber +} + +// connTrackImmutable is the non-manipulatable part of the tuple. +type connTrackImmutable struct { + // addr is destination address of the tuple. + addr tcpip.Address + + // direction is direction (original or reply) of the tuple. + direction ctDirection + + // port is destination port of the tuple. + port uint16 + + // protocol is transport layer protocol. + protocol tcpip.TransportProtocolNumber +} + +// connTrackTuple represents the tuple which is created from the +// packet. +type connTrackTuple struct { + // dst is non-manipulatable part of the tuple. + dst connTrackImmutable + + // src is manipulatable part of the tuple. + src connTrackMutable +} + +// connTrackTupleHolder is the container of tuple and connection. +type ConnTrackTupleHolder struct { + // conn is pointer to the connection tracking entry. + conn *connTrack + + // tuple is original or reply tuple. + tuple connTrackTuple +} + +// connTrack is the connection. +type connTrack struct { + // originalTupleHolder contains tuple in original direction. + originalTupleHolder ConnTrackTupleHolder + + // replyTupleHolder contains tuple in reply direction. + replyTupleHolder ConnTrackTupleHolder + + // status indicates connection is new or established. + status connStatus + + // timeout indicates the time connection should be active. + timeout time.Duration + + // manip indicates if the packet should be manipulated. + manip manipType + + // tcb is TCB control block. It is used to keep track of states + // of tcp connection. + tcb tcpconntrack.TCB + + // tcbHook indicates if the packet is inbound or outbound to + // update the state of tcb. + tcbHook Hook +} + +// ConnTrackTable contains a map of all existing connections created for +// NAT rules. +type ConnTrackTable struct { + // connMu protects connTrackTable. + connMu sync.RWMutex + + // connTrackTable maintains a map of tuples needed for connection tracking + // for iptables NAT rules. The key for the map is an integer calculated + // using seed, source address, destination address, source port and + // destination port. + CtMap map[uint32]ConnTrackTupleHolder + + // seed is a one-time random value initialized at stack startup + // and is used in calculation of hash key for connection tracking + // table. + Seed uint32 +} + +// parseHeaders sets headers in the packet. +func parseHeaders(pkt *PacketBuffer) { + newPkt := pkt.Clone() + + // Set network header. + hdr, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + netHeader := header.IPv4(hdr) + newPkt.NetworkHeader = hdr + length := int(netHeader.HeaderLength()) + + // TODO(gvisor.dev/issue/170): Need to support for other + // protocols as well. + // Set transport header. + switch protocol := netHeader.TransportProtocol(); protocol { + case header.UDPProtocolNumber: + if newPkt.TransportHeader == nil { + h, ok := newPkt.Data.PullUp(length + header.UDPMinimumSize) + if !ok { + return + } + newPkt.TransportHeader = buffer.View(header.UDP(h[length:])) + } + case header.TCPProtocolNumber: + if newPkt.TransportHeader == nil { + h, ok := newPkt.Data.PullUp(length + header.TCPMinimumSize) + if !ok { + return + } + newPkt.TransportHeader = buffer.View(header.TCP(h[length:])) + } + } + pkt.NetworkHeader = newPkt.NetworkHeader + pkt.TransportHeader = newPkt.TransportHeader +} + +// packetToTuple converts packet to a tuple in original direction. +func packetToTuple(pkt PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { + var tuple connTrackTuple + + netHeader := header.IPv4(pkt.NetworkHeader) + // TODO(gvisor.dev/issue/170): Need to support for other + // protocols as well. + if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { + return tuple, tcpip.ErrUnknownProtocol + } + tcpHeader := header.TCP(pkt.TransportHeader) + if tcpHeader == nil { + return tuple, tcpip.ErrUnknownProtocol + } + + tuple.src.addr = netHeader.SourceAddress() + tuple.src.port = tcpHeader.SourcePort() + tuple.src.protocol = header.IPv4ProtocolNumber + + tuple.dst.addr = netHeader.DestinationAddress() + tuple.dst.port = tcpHeader.DestinationPort() + tuple.dst.protocol = netHeader.TransportProtocol() + + return tuple, nil +} + +// getReplyTuple creates reply tuple for the given tuple. +func getReplyTuple(tuple connTrackTuple) connTrackTuple { + var replyTuple connTrackTuple + replyTuple.src.addr = tuple.dst.addr + replyTuple.src.port = tuple.dst.port + replyTuple.src.protocol = tuple.src.protocol + replyTuple.dst.addr = tuple.src.addr + replyTuple.dst.port = tuple.src.port + replyTuple.dst.protocol = tuple.dst.protocol + replyTuple.dst.direction = dirReply + + return replyTuple +} + +// makeNewConn creates new connection. +func makeNewConn(tuple, replyTuple connTrackTuple) connTrack { + var conn connTrack + conn.status = connNew + conn.originalTupleHolder.tuple = tuple + conn.originalTupleHolder.conn = &conn + conn.replyTupleHolder.tuple = replyTuple + conn.replyTupleHolder.conn = &conn + + return conn +} + +// getTupleHash returns hash of the tuple. The fields used for +// generating hash are seed (generated once for stack), source address, +// destination address, source port and destination ports. +func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 { + h := jenkins.Sum32(ct.Seed) + h.Write([]byte(tuple.src.addr)) + h.Write([]byte(tuple.dst.addr)) + portBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(portBuf, tuple.src.port) + h.Write([]byte(portBuf)) + binary.LittleEndian.PutUint16(portBuf, tuple.dst.port) + h.Write([]byte(portBuf)) + + return h.Sum32() +} + +// connTrackForPacket returns connTrack for packet. +// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other +// transport protocols. +func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) { + if hook == Prerouting { + // Headers will not be set in Prerouting. + // TODO(gvisor.dev/issue/170): Change this after parsing headers + // code is added. + parseHeaders(pkt) + } + + var dir ctDirection + tuple, err := packetToTuple(*pkt, hook) + if err != nil { + return nil, dir + } + + ct.connMu.Lock() + defer ct.connMu.Unlock() + + connTrackTable := ct.CtMap + hash := ct.getTupleHash(tuple) + + var conn *connTrack + switch createConn { + case true: + // If connection does not exist for the hash, create a new + // connection. + replyTuple := getReplyTuple(tuple) + replyHash := ct.getTupleHash(replyTuple) + newConn := makeNewConn(tuple, replyTuple) + conn = &newConn + + // Add tupleHolders to the map. + // TODO(gvisor.dev/issue/170): Need to support collisions using linked list. + ct.CtMap[hash] = conn.originalTupleHolder + ct.CtMap[replyHash] = conn.replyTupleHolder + default: + tupleHolder, ok := connTrackTable[hash] + if !ok { + return nil, dir + } + + // If this is the reply of new connection, set the connection + // status as ESTABLISHED. + conn = tupleHolder.conn + if conn.status == connNew && tupleHolder.tuple.dst.direction == dirReply { + conn.status = connEstablished + } + if tupleHolder.conn == nil { + panic("tupleHolder has null connection tracking entry") + } + + dir = tupleHolder.tuple.dst.direction + } + return conn, dir +} + +// SetNatInfo will manipulate the tuples according to iptables NAT rules. +func (ct *ConnTrackTable) SetNatInfo(pkt *PacketBuffer, rt RedirectTarget, hook Hook) { + // Get the connection. Connection is always created before this + // function is called. + conn, _ := ct.connTrackForPacket(pkt, hook, false) + if conn == nil { + panic("connection should be created to manipulate tuples.") + } + replyTuple := conn.replyTupleHolder.tuple + replyHash := ct.getTupleHash(replyTuple) + + // TODO(gvisor.dev/issue/170): Support only redirect of ports. Need to + // support changing of address for Prerouting. + + // Change the port as per the iptables rule. This tuple will be used + // to manipulate the packet in HandlePacket. + conn.replyTupleHolder.tuple.src.addr = rt.MinIP + conn.replyTupleHolder.tuple.src.port = rt.MinPort + newHash := ct.getTupleHash(conn.replyTupleHolder.tuple) + + // Add the changed tuple to the map. + ct.connMu.Lock() + defer ct.connMu.Unlock() + ct.CtMap[newHash] = conn.replyTupleHolder + if hook == Output { + conn.replyTupleHolder.conn.manip = manipDstOutput + } + + // Delete the old tuple. + delete(ct.CtMap, replyHash) +} + +// handlePacketPrerouting manipulates ports for packets in Prerouting hook. +// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.. +func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) { + netHeader := header.IPv4(pkt.NetworkHeader) + tcpHeader := header.TCP(pkt.TransportHeader) + + // For prerouting redirection, packets going in the original direction + // have their destinations modified and replies have their sources + // modified. + switch dir { + case dirOriginal: + port := conn.replyTupleHolder.tuple.src.port + tcpHeader.SetDestinationPort(port) + netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr) + case dirReply: + port := conn.originalTupleHolder.tuple.dst.port + tcpHeader.SetSourcePort(port) + netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr) + } + + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) +} + +// handlePacketOutput manipulates ports for packets in Output hook. +func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, dir ctDirection) { + netHeader := header.IPv4(pkt.NetworkHeader) + tcpHeader := header.TCP(pkt.TransportHeader) + + // For output redirection, packets going in the original direction + // have their destinations modified and replies have their sources + // modified. For prerouting redirection, we only reach this point + // when replying, so packet sources are modified. + if conn.manip == manipDstOutput && dir == dirOriginal { + port := conn.replyTupleHolder.tuple.src.port + tcpHeader.SetDestinationPort(port) + netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr) + } else { + port := conn.originalTupleHolder.tuple.dst.port + tcpHeader.SetSourcePort(port) + netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr) + } + + // Calculate the TCP checksum and set it. + tcpHeader.SetChecksum(0) + hdr := &pkt.Header + length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength()) + xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) + if gso != nil && gso.NeedsCsum { + tcpHeader.SetChecksum(xsum) + } else if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, int(tcpHeader.DataOffset()), pkt.Data.Size()) + tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) + } + + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) +} + +// HandlePacket will manipulate the port and address of the packet if the +// connection exists. +func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) { + if pkt.NatDone { + return + } + + if hook != Prerouting && hook != Output { + return + } + + conn, dir := ct.connTrackForPacket(pkt, hook, false) + // Connection or Rule not found for the packet. + if conn == nil { + return + } + + netHeader := header.IPv4(pkt.NetworkHeader) + // TODO(gvisor.dev/issue/170): Need to support for other transport + // protocols as well. + if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { + return + } + + tcpHeader := header.TCP(pkt.TransportHeader) + if tcpHeader == nil { + return + } + + switch hook { + case Prerouting: + handlePacketPrerouting(pkt, conn, dir) + case Output: + handlePacketOutput(pkt, conn, gso, r, dir) + } + pkt.NatDone = true + + // Update the state of tcb. + // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle + // other tcp states. + var st tcpconntrack.Result + if conn.tcb.IsEmpty() { + conn.tcb.Init(tcpHeader) + conn.tcbHook = hook + } else { + switch hook { + case conn.tcbHook: + st = conn.tcb.UpdateStateOutbound(tcpHeader) + default: + st = conn.tcb.UpdateStateInbound(tcpHeader) + } + } + + // Delete conntrack if tcp connection is closed. + if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset { + ct.deleteConnTrack(conn) + } +} + +// deleteConnTrack deletes the connection. +func (ct *ConnTrackTable) deleteConnTrack(conn *connTrack) { + if conn == nil { + return + } + + tuple := conn.originalTupleHolder.tuple + hash := ct.getTupleHash(tuple) + replyTuple := conn.replyTupleHolder.tuple + replyHash := ct.getTupleHash(replyTuple) + + ct.connMu.Lock() + defer ct.connMu.Unlock() + + delete(ct.CtMap, hash) + delete(ct.CtMap, replyHash) +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6b91159d4..7c3c47d50 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -17,6 +17,7 @@ package stack import ( "fmt" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -110,6 +111,10 @@ func DefaultTables() IPTables { Prerouting: []string{TablenameMangle, TablenameNat}, Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, }, + connections: ConnTrackTable{ + CtMap: make(map[uint32]ConnTrackTupleHolder), + Seed: generateRandUint32(), + }, } } @@ -173,12 +178,16 @@ const ( // dropped. // // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address) bool { + // Packets are manipulated only if connection and matching + // NAT rule exists. + it.connections.HandlePacket(pkt, hook, gso, r) + // Go through each table containing the hook. for _, tablename := range it.Priorities[hook] { table := it.Tables[tablename] ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -189,7 +198,7 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // Any Return from a built-in chain means we have to // call the underflow. underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt); v { + switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, address); v { case RuleAccept: continue case RuleDrop: @@ -219,26 +228,34 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if ok := it.Check(hook, *pkt); !ok { - if drop == nil { - drop = make(map[*PacketBuffer]struct{}) + if !pkt.NatDone { + if ok := it.Check(hook, pkt, gso, r, ""); !ok { + if drop == nil { + drop = make(map[*PacketBuffer]struct{}) + } + drop[pkt] = struct{}{} + } + if pkt.NatDone { + if natPkts == nil { + natPkts = make(map[*PacketBuffer]struct{}) + } + natPkts[pkt] = struct{}{} } - drop[pkt] = struct{}{} } } - return drop + return drop, natPkts } // Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a // precondition. -func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address); verdict { case RuleAccept: return chainAccept @@ -255,7 +272,7 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address); verdict { case chainAccept: return chainAccept case chainDrop: @@ -279,9 +296,9 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx } // Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a // precondition. -func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in @@ -304,7 +321,7 @@ func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx // Go through each rule matcher. If they all match, run // the rule target. for _, matcher := range rule.Matchers { - matches, hotdrop := matcher.Match(hook, pkt, "") + matches, hotdrop := matcher.Match(hook, *pkt, "") if hotdrop { return RuleDrop, 0 } @@ -315,7 +332,7 @@ func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt) + return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) } func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool { diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 8be61f4b1..36cc6275d 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -24,7 +24,7 @@ import ( type AcceptTarget struct{} // Action implements Target.Action. -func (AcceptTarget) Action(packet PacketBuffer) (RuleVerdict, int) { +func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 } @@ -32,7 +32,7 @@ func (AcceptTarget) Action(packet PacketBuffer) (RuleVerdict, int) { type DropTarget struct{} // Action implements Target.Action. -func (DropTarget) Action(packet PacketBuffer) (RuleVerdict, int) { +func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } @@ -41,7 +41,7 @@ func (DropTarget) Action(packet PacketBuffer) (RuleVerdict, int) { type ErrorTarget struct{} // Action implements Target.Action. -func (ErrorTarget) Action(packet PacketBuffer) (RuleVerdict, int) { +func (ErrorTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -52,7 +52,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (UserChainTarget) Action(PacketBuffer) (RuleVerdict, int) { +func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -61,7 +61,7 @@ func (UserChainTarget) Action(PacketBuffer) (RuleVerdict, int) { type ReturnTarget struct{} // Action implements Target.Action. -func (ReturnTarget) Action(PacketBuffer) (RuleVerdict, int) { +func (ReturnTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } @@ -75,16 +75,16 @@ type RedirectTarget struct { // redirect. RangeProtoSpecified bool - // Min address used to redirect. + // MinIP indicates address used to redirect. MinIP tcpip.Address - // Max address used to redirect. + // MaxIP indicates address used to redirect. MaxIP tcpip.Address - // Min port used to redirect. + // MinPort indicates port used to redirect. MinPort uint16 - // Max port used to redirect. + // MaxPort indicates port used to redirect. MaxPort uint16 } @@ -92,61 +92,76 @@ type RedirectTarget struct { // TODO(gvisor.dev/issue/170): Parse headers without copying. The current // implementation only works for PREROUTING and calls pkt.Clone(), neither // of which should be the case. -func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { - newPkt := pkt.Clone() +func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { + // Packet is already manipulated. + if pkt.NatDone { + return RuleAccept, 0 + } // Set network header. - headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return RuleDrop, 0 + if hook == Prerouting { + parseHeaders(pkt) } - netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView - hlen := int(netHeader.HeaderLength()) - tlen := int(netHeader.TotalLength()) - newPkt.Data.TrimFront(hlen) - newPkt.Data.CapLength(tlen - hlen) + // Drop the packet if network and transport header are not set. + if pkt.NetworkHeader == nil || pkt.TransportHeader == nil { + return RuleDrop, 0 + } - // TODO(gvisor.dev/issue/170): Change destination address to - // loopback or interface address on which the packet was - // received. + // Change the address to localhost (127.0.0.1) in Output and + // to primary address of the incoming interface in Prerouting. + switch hook { + case Output: + rt.MinIP = tcpip.Address([]byte{127, 0, 0, 1}) + rt.MaxIP = tcpip.Address([]byte{127, 0, 0, 1}) + case Prerouting: + rt.MinIP = address + rt.MaxIP = address + default: + panic("redirect target is supported only on output and prerouting hooks") + } // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if // we need to change dest address (for OUTPUT chain) or ports. + netHeader := header.IPv4(pkt.NetworkHeader) switch protocol := netHeader.TransportProtocol(); protocol { case header.UDPProtocolNumber: - var udpHeader header.UDP - if newPkt.TransportHeader != nil { - udpHeader = header.UDP(newPkt.TransportHeader) - } else { - if pkt.Data.Size() < header.UDPMinimumSize { - return RuleDrop, 0 - } - hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) - if !ok { - return RuleDrop, 0 + udpHeader := header.UDP(pkt.TransportHeader) + udpHeader.SetDestinationPort(rt.MinPort) + + // Calculate UDP checksum and set it. + if hook == Output { + udpHeader.SetChecksum(0) + hdr := &pkt.Header + length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength()) + + // Only calculate the checksum if offloading isn't supported. + if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + xsum := r.PseudoHeaderChecksum(protocol, length) + for _, v := range pkt.Data.Views() { + xsum = header.Checksum(v, xsum) + } + udpHeader.SetChecksum(0) + udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) } - udpHeader = header.UDP(hdr) } - udpHeader.SetDestinationPort(rt.MinPort) + // Change destination address. + netHeader.SetDestinationAddress(rt.MinIP) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + pkt.NatDone = true case header.TCPProtocolNumber: - var tcpHeader header.TCP - if newPkt.TransportHeader != nil { - tcpHeader = header.TCP(newPkt.TransportHeader) - } else { - if pkt.Data.Size() < header.TCPMinimumSize { - return RuleDrop, 0 - } - hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) - if !ok { - return RuleDrop, 0 - } - tcpHeader = header.TCP(hdr) + if ct == nil { + return RuleAccept, 0 + } + + // Set up conection for matching NAT rule. + // Only the first packet of the connection comes here. + // Other packets will be manipulated in connection tracking. + if conn, _ := ct.connTrackForPacket(pkt, hook, true); conn != nil { + ct.SetNatInfo(pkt, rt, hook) + ct.HandlePacket(pkt, hook, gso, r) } - // TODO(gvisor.dev/issue/170): Need to recompute checksum - // and implement nat connection tracking to support TCP. - tcpHeader.SetDestinationPort(rt.MinPort) default: return RuleDrop, 0 } diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 2ffb55f2a..1bb0ba1bd 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -82,6 +82,8 @@ type IPTables struct { // list is the order in which each table should be visited for that // hook. Priorities map[Hook][]string + + connections ConnTrackTable } // A Table defines a set of chains and hooks into the network stack. It is @@ -176,5 +178,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. - Action(packet PacketBuffer) (RuleVerdict, int) + Action(packet *PacketBuffer, connections *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 7b54919bb..8f4c1fe42 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1230,8 +1230,10 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. if protocol == header.IPv4ProtocolNumber { + // iptables filtering. ipt := n.stack.IPTables() - if ok := ipt.Check(Prerouting, pkt); !ok { + address := n.primaryAddress(protocol) + if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address); !ok { // iptables is telling us to drop the packet. return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9ff80ab24..926df4d7b 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -72,6 +72,10 @@ type PacketBuffer struct { EgressRoute *Route GSOOptions *GSO NetworkProtocolNumber tcpip.NetworkProtocolNumber + + // NatDone indicates if the packet has been manipulated as per NAT + // iptables rule. + NatDone bool } // Clone makes a copy of pk. It clones the Data field, which creates a new diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 53148dc03..150297ab9 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -261,3 +261,16 @@ func (r *Route) MakeLoopedRoute() Route { func (r *Route) Stack() *Stack { return r.ref.stack() } + +// ReverseRoute returns new route with given source and destination address. +func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { + return Route{ + NetProto: r.NetProto, + LocalAddress: dst, + LocalLinkAddress: r.RemoteLinkAddress, + RemoteAddress: src, + RemoteLinkAddress: r.LocalLinkAddress, + ref: r.ref, + Loop: r.Loop, + } +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 4a2dc3dc6..e33fae4eb 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1885,3 +1885,22 @@ func generateRandInt64() int64 { } return v } + +// FindNetworkEndpoint returns the network endpoint for the given address. +func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) { + s.mu.Lock() + defer s.mu.Unlock() + + for _, nic := range s.nics { + id := NetworkEndpointID{address} + + if ref, ok := nic.mu.endpoints[id]; ok { + nic.mu.RLock() + defer nic.mu.RUnlock() + + // An endpoint with this id exists, check if it can be used and return it. + return ref.ep, nil + } + } + return nil, tcpip.ErrBadAddress +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index f2aa69069..f38eb6833 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -115,7 +115,7 @@ go_test( size = "small", srcs = ["rcv_test.go"], deps = [ - ":tcp", + "//pkg/tcpip/header", "//pkg/tcpip/seqnum", ], ) diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index a4b73b588..6fe97fefd 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -70,24 +70,7 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale // acceptable checks if the segment sequence number range is acceptable // according to the table on page 26 of RFC 793. func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { - return Acceptable(segSeq, segLen, r.rcvNxt, r.rcvAcc) -} - -// Acceptable checks if a segment that starts at segSeq and has length segLen is -// "acceptable" for arriving in a receive window that starts at rcvNxt and ends -// before rcvAcc, according to the table on page 26 and 69 of RFC 793. -func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool { - if rcvNxt == rcvAcc { - return segLen == 0 && segSeq == rcvNxt - } - if segLen == 0 { - // rcvWnd is incremented by 1 because that is Linux's behavior despite the - // RFC. - return segSeq.InRange(rcvNxt, rcvAcc.Add(1)) - } - // Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming - // the payload, so we'll accept any payload that overlaps the receieve window. - return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThan(rcvAcc) + return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvAcc) } // getSendParams returns the parameters needed by the sender when building diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go index dc02729ce..c9eeff935 100644 --- a/pkg/tcpip/transport/tcp/rcv_test.go +++ b/pkg/tcpip/transport/tcp/rcv_test.go @@ -17,8 +17,8 @@ package rcv_test import ( "testing" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" ) func TestAcceptable(t *testing.T) { @@ -67,8 +67,8 @@ func TestAcceptable(t *testing.T) { {105, 2, 108, 108, false}, {105, 2, 109, 109, false}, } { - if got := tcp.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want { - t.Errorf("tcp.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want) + if got := header.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want { + t.Errorf("header.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want) } } } diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD index 2025ff757..3ad6994a7 100644 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ b/pkg/tcpip/transport/tcpconntrack/BUILD @@ -9,7 +9,6 @@ go_library( deps = [ "//pkg/tcpip/header", "//pkg/tcpip/seqnum", - "//pkg/tcpip/transport/tcp", ], ) diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go index 30d05200f..12bc1b5b5 100644 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go @@ -20,7 +20,6 @@ package tcpconntrack import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" ) // Result is returned when the state of a TCB is updated in response to an @@ -312,7 +311,7 @@ type stream struct { // the window is zero, if it's a packet with no payload and sequence number // equal to una. func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { - return tcp.Acceptable(segSeq, segLen, s.una, s.end) + return header.Acceptable(segSeq, segLen, s.una, s.end) } // closed determines if the stream has already been closed. This happens when @@ -338,3 +337,16 @@ func logicalLen(tcp header.TCP) seqnum.Size { } return l } + +// IsEmpty returns true if tcb is not initialized. +func (t *TCB) IsEmpty() bool { + if t.inbound != (stream{}) || t.outbound != (stream{}) { + return false + } + + if t.firstFin != nil || t.state != ResultDrop { + return false + } + + return true +} diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go index f6d974b85..b1382d353 100644 --- a/test/iptables/filter_output.go +++ b/test/iptables/filter_output.go @@ -42,7 +42,7 @@ func (FilterOutputDropTCPDestPort) Name() string { // ContainerAction implements TestCase.ContainerAction. func (FilterOutputDropTCPDestPort) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { + if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", "1024:65535", "-j", "DROP"); err != nil { return err } diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 334d8e676..63a862d35 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -115,30 +115,6 @@ func TestFilterInputDropOnlyUDP(t *testing.T) { singleTest(t, FilterInputDropOnlyUDP{}) } -func TestNATRedirectUDPPort(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") - singleTest(t, NATRedirectUDPPort{}) -} - -func TestNATRedirectTCPPort(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") - singleTest(t, NATRedirectTCPPort{}) -} - -func TestNATDropUDP(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") - singleTest(t, NATDropUDP{}) -} - -func TestNATAcceptAll(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") - singleTest(t, NATAcceptAll{}) -} - func TestFilterInputDropTCPDestPort(t *testing.T) { singleTest(t, FilterInputDropTCPDestPort{}) } @@ -164,14 +140,10 @@ func TestFilterInputReturnUnderflow(t *testing.T) { } func TestFilterOutputDropTCPDestPort(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("filter OUTPUT isn't supported yet (gvisor.dev/issue/170).") singleTest(t, FilterOutputDropTCPDestPort{}) } func TestFilterOutputDropTCPSrcPort(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("filter OUTPUT isn't supported yet (gvisor.dev/issue/170).") singleTest(t, FilterOutputDropTCPSrcPort{}) } @@ -235,44 +207,54 @@ func TestOutputInvertDestination(t *testing.T) { singleTest(t, FilterOutputInvertDestination{}) } +func TestNATPreRedirectUDPPort(t *testing.T) { + singleTest(t, NATPreRedirectUDPPort{}) +} + +func TestNATPreRedirectTCPPort(t *testing.T) { + singleTest(t, NATPreRedirectTCPPort{}) +} + +func TestNATOutRedirectUDPPort(t *testing.T) { + singleTest(t, NATOutRedirectUDPPort{}) +} + +func TestNATOutRedirectTCPPort(t *testing.T) { + singleTest(t, NATOutRedirectTCPPort{}) +} + +func TestNATDropUDP(t *testing.T) { + singleTest(t, NATDropUDP{}) +} + +func TestNATAcceptAll(t *testing.T) { + singleTest(t, NATAcceptAll{}) +} + func TestNATOutRedirectIP(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") singleTest(t, NATOutRedirectIP{}) } func TestNATOutDontRedirectIP(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") singleTest(t, NATOutDontRedirectIP{}) } func TestNATOutRedirectInvert(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") singleTest(t, NATOutRedirectInvert{}) } func TestNATPreRedirectIP(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") singleTest(t, NATPreRedirectIP{}) } func TestNATPreDontRedirectIP(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") singleTest(t, NATPreDontRedirectIP{}) } func TestNATPreRedirectInvert(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") singleTest(t, NATPreRedirectInvert{}) } func TestNATRedirectRequiresProtocol(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") singleTest(t, NATRedirectRequiresProtocol{}) } diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index 2a00677be..2f988cd18 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -151,7 +151,7 @@ func connectTCP(ip net.IP, port int, timeout time.Duration) error { return err } if err := testutil.Poll(callback, timeout); err != nil { - return fmt.Errorf("timed out waiting to connect IP, most recent error: %v", err) + return fmt.Errorf("timed out waiting to connect IP on port %v, most recent error: %v", port, err) } return nil diff --git a/test/iptables/nat.go b/test/iptables/nat.go index 40096901c..0a10ce7fe 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -26,8 +26,10 @@ const ( ) func init() { - RegisterTestCase(NATRedirectUDPPort{}) - RegisterTestCase(NATRedirectTCPPort{}) + RegisterTestCase(NATPreRedirectUDPPort{}) + RegisterTestCase(NATPreRedirectTCPPort{}) + RegisterTestCase(NATOutRedirectUDPPort{}) + RegisterTestCase(NATOutRedirectTCPPort{}) RegisterTestCase(NATDropUDP{}) RegisterTestCase(NATAcceptAll{}) RegisterTestCase(NATPreRedirectIP{}) @@ -39,16 +41,16 @@ func init() { RegisterTestCase(NATRedirectRequiresProtocol{}) } -// NATRedirectUDPPort tests that packets are redirected to different port. -type NATRedirectUDPPort struct{} +// NATPreRedirectUDPPort tests that packets are redirected to different port. +type NATPreRedirectUDPPort struct{} // Name implements TestCase.Name. -func (NATRedirectUDPPort) Name() string { - return "NATRedirectUDPPort" +func (NATPreRedirectUDPPort) Name() string { + return "NATPreRedirectUDPPort" } // ContainerAction implements TestCase.ContainerAction. -func (NATRedirectUDPPort) ContainerAction(ip net.IP) error { +func (NATPreRedirectUDPPort) ContainerAction(ip net.IP) error { if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { return err } @@ -61,33 +63,53 @@ func (NATRedirectUDPPort) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (NATRedirectUDPPort) LocalAction(ip net.IP) error { +func (NATPreRedirectUDPPort) LocalAction(ip net.IP) error { return sendUDPLoop(ip, acceptPort, sendloopDuration) } -// NATRedirectTCPPort tests that connections are redirected on specified ports. -type NATRedirectTCPPort struct{} +// NATPreRedirectTCPPort tests that connections are redirected on specified ports. +type NATPreRedirectTCPPort struct{} // Name implements TestCase.Name. -func (NATRedirectTCPPort) Name() string { - return "NATRedirectTCPPort" +func (NATPreRedirectTCPPort) Name() string { + return "NATPreRedirectTCPPort" } // ContainerAction implements TestCase.ContainerAction. -func (NATRedirectTCPPort) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { +func (NATPreRedirectTCPPort) ContainerAction(ip net.IP) error { + if err := natTable("-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { return err } // Listen for TCP packets on redirect port. - return listenTCP(redirectPort, sendloopDuration) + return listenTCP(acceptPort, sendloopDuration) } // LocalAction implements TestCase.LocalAction. -func (NATRedirectTCPPort) LocalAction(ip net.IP) error { +func (NATPreRedirectTCPPort) LocalAction(ip net.IP) error { return connectTCP(ip, dropPort, sendloopDuration) } +// NATOutRedirectUDPPort tests that packets are redirected to different port. +type NATOutRedirectUDPPort struct{} + +// Name implements TestCase.Name. +func (NATOutRedirectUDPPort) Name() string { + return "NATOutRedirectUDPPort" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutRedirectUDPPort) ContainerAction(ip net.IP) error { + dest := []byte{200, 0, 0, 1} + return loopbackTest(dest, "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)) +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutRedirectUDPPort) LocalAction(ip net.IP) error { + // No-op. + return nil +} + // NATDropUDP tests that packets are not received in ports other than redirect // port. type NATDropUDP struct{} @@ -329,3 +351,52 @@ func loopbackTest(dest net.IP, args ...string) error { // sendCh will always take the full sendloop time. return <-sendCh } + +// NATOutRedirectTCPPort tests that connections are redirected on specified ports. +type NATOutRedirectTCPPort struct{} + +// Name implements TestCase.Name. +func (NATOutRedirectTCPPort) Name() string { + return "NATOutRedirectTCPPort" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutRedirectTCPPort) ContainerAction(ip net.IP) error { + if err := natTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { + return err + } + + timeout := 20 * time.Second + dest := []byte{127, 0, 0, 1} + localAddr := net.TCPAddr{ + IP: dest, + Port: acceptPort, + } + + // Starts listening on port. + lConn, err := net.ListenTCP("tcp", &localAddr) + if err != nil { + return err + } + defer lConn.Close() + + // Accept connections on port. + lConn.SetDeadline(time.Now().Add(timeout)) + err = connectTCP(ip, dropPort, timeout) + if err != nil { + return err + } + + conn, err := lConn.AcceptTCP() + if err != nil { + return err + } + conn.Close() + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutRedirectTCPPort) LocalAction(ip net.IP) error { + return nil +} -- cgit v1.2.3 From cfd30665c1d857f20dd05e67c6da6833770e2141 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 8 May 2020 15:39:04 -0700 Subject: iptables - filter packets using outgoing interface. Enables commands with -o (--out-interface) for iptables rules. $ iptables -A OUTPUT -o eth0 -j ACCEPT PiperOrigin-RevId: 310642286 --- pkg/abi/linux/netfilter.go | 2 +- pkg/sentry/socket/netfilter/netfilter.go | 40 ++++++-- pkg/tcpip/network/ipv4/ipv4.go | 8 +- pkg/tcpip/stack/iptables.go | 42 ++++++-- pkg/tcpip/stack/iptables_types.go | 13 +++ pkg/tcpip/stack/nic.go | 2 +- pkg/tcpip/stack/stack.go | 16 ++- test/iptables/filter_output.go | 170 +++++++++++++++++++++++++++++++ test/iptables/iptables_test.go | 24 +++++ test/iptables/iptables_util.go | 15 +++ 10 files changed, 309 insertions(+), 23 deletions(-) (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index a8d4f9d69..46d8b0b42 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -146,7 +146,7 @@ type IPTIP struct { // OutputInterface is the output network interface. OutputInterface [IFNAMSIZ]byte - // InputInterfaceMask is the intput interface mask. + // InputInterfaceMask is the input interface mask. InputInterfaceMask [IFNAMSIZ]byte // OuputInterfaceMask is the output interface mask. diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 40736fb38..41a1ce031 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -17,6 +17,7 @@ package netfilter import ( + "bytes" "errors" "fmt" @@ -204,6 +205,16 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI TargetOffset: linux.SizeOfIPTEntry, }, } + copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst) + copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask) + copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface) + copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) + if rule.Filter.DstInvert { + entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP + } + if rule.Filter.OutputInterfaceInvert { + entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT + } for _, matcher := range rule.Matchers { // Serialize the matcher and add it to the @@ -719,11 +730,27 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize { return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) } + + n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) + if n == -1 { + n = len(iptip.OutputInterface) + } + ifname := string(iptip.OutputInterface[:n]) + + n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0) + if n == -1 { + n = len(iptip.OutputInterfaceMask) + } + ifnameMask := string(iptip.OutputInterfaceMask[:n]) + return stack.IPHeaderFilter{ - Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), - Dst: tcpip.Address(iptip.Dst[:]), - DstMask: tcpip.Address(iptip.DstMask[:]), - DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0, + Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), + Dst: tcpip.Address(iptip.Dst[:]), + DstMask: tcpip.Address(iptip.DstMask[:]), + DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0, + OutputInterface: ifname, + OutputInterfaceMask: ifnameMask, + OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0, }, nil } @@ -732,16 +759,15 @@ func containsUnsupportedFields(iptip linux.IPTIP) bool { // - Protocol // - Dst and DstMask // - The inverse destination IP check flag + // - OutputInterface, OutputInterfaceMask and its inverse. var emptyInetAddr = linux.InetAddr{} var emptyInterface = [linux.IFNAMSIZ]byte{} // Disable any supported inverse flags. - inverseMask := uint8(linux.IPT_INV_DSTIP) + inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_VIA_OUT) return iptip.Src != emptyInetAddr || iptip.SrcMask != emptyInetAddr || iptip.InputInterface != emptyInterface || - iptip.OutputInterface != emptyInterface || iptip.InputInterfaceMask != emptyInterface || - iptip.OutputInterfaceMask != emptyInterface || iptip.Flags != 0 || iptip.InverseFlags&^inverseMask != 0 } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 9db42b2a4..64046cbbf 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -249,10 +249,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) + nicName := e.stack.FindNICNameFromID(e.NICID()) // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Output, &pkt, gso, r, ""); !ok { + if ok := ipt.Check(stack.Output, &pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. return nil } @@ -319,10 +320,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe pkt = pkt.Next() } + nicName := e.stack.FindNICNameFromID(e.NICID()) // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r) + dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -445,7 +447,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Input, &pkt, nil, nil, ""); !ok { + if ok := ipt.Check(stack.Input, &pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. return } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 7c3c47d50..443423b3c 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -16,6 +16,7 @@ package stack import ( "fmt" + "strings" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -178,7 +179,7 @@ const ( // dropped. // // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool { // Packets are manipulated only if connection and matching // NAT rule exists. it.connections.HandlePacket(pkt, hook, gso, r) @@ -187,7 +188,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr for _, tablename := range it.Priorities[hook] { table := it.Tables[tablename] ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -228,10 +229,10 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, nicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.Check(hook, pkt, gso, r, ""); !ok { + if ok := it.Check(hook, pkt, gso, r, "", nicName); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } @@ -251,11 +252,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * // Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a // precondition. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { case RuleAccept: return chainAccept @@ -272,7 +273,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address, nicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -298,7 +299,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a // precondition. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in @@ -313,7 +314,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // Check whether the packet matches the IP header filter. - if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader)) { + if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader), hook, nicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -335,7 +336,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) } -func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool { +func filterMatch(filter IPHeaderFilter, hdr header.IPv4, hook Hook, nicName string) bool { // TODO(gvisor.dev/issue/170): Support other fields of the filter. // Check the transport protocol. if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() { @@ -355,5 +356,26 @@ func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool { return false } + // Check the output interface. + // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING + // hooks after supported. + if hook == Output { + n := len(filter.OutputInterface) + if n == 0 { + return true + } + + // If the interface name ends with '+', any interface which begins + // with the name should be matched. + ifName := filter.OutputInterface + matches = true + if strings.HasSuffix(ifName, "+") { + matches = strings.HasPrefix(nicName, ifName[:n-1]) + } else { + matches = nicName == ifName + } + return filter.OutputInterfaceInvert != matches + } + return true } diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 1bb0ba1bd..fe06007ae 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -158,6 +158,19 @@ type IPHeaderFilter struct { // true the filter will match packets that fail the destination // comparison. DstInvert bool + + // OutputInterface matches the name of the outgoing interface for the + // packet. + OutputInterface string + + // OutputInterfaceMask masks the characters of the interface name when + // comparing with OutputInterface. + OutputInterfaceMask string + + // OutputInterfaceInvert inverts the meaning of outgoing interface check, + // i.e. when true the filter will match packets that fail the outgoing + // interface comparison. + OutputInterfaceInvert bool } // A Matcher is the interface for matching packets. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 8f4c1fe42..54103fdb3 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1233,7 +1233,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) - if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address); !ok { + if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address, ""); !ok { // iptables is telling us to drop the packet. return } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index e33fae4eb..b39ffa9fb 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1898,9 +1898,23 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres nic.mu.RLock() defer nic.mu.RUnlock() - // An endpoint with this id exists, check if it can be used and return it. + // An endpoint with this id exists, check if it can be + // used and return it. return ref.ep, nil } } return nil, tcpip.ErrBadAddress } + +// FindNICNameFromID returns the name of the nic for the given NICID. +func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return "" + } + + return nic.Name() +} diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go index b1382d353..c145bd1e9 100644 --- a/test/iptables/filter_output.go +++ b/test/iptables/filter_output.go @@ -29,6 +29,12 @@ func init() { RegisterTestCase(FilterOutputAcceptUDPOwner{}) RegisterTestCase(FilterOutputDropUDPOwner{}) RegisterTestCase(FilterOutputOwnerFail{}) + RegisterTestCase(FilterOutputInterfaceAccept{}) + RegisterTestCase(FilterOutputInterfaceDrop{}) + RegisterTestCase(FilterOutputInterface{}) + RegisterTestCase(FilterOutputInterfaceBeginsWith{}) + RegisterTestCase(FilterOutputInterfaceInvertDrop{}) + RegisterTestCase(FilterOutputInterfaceInvertAccept{}) } // FilterOutputDropTCPDestPort tests that connections are not accepted on @@ -286,3 +292,167 @@ func (FilterOutputInvertDestination) ContainerAction(ip net.IP) error { func (FilterOutputInvertDestination) LocalAction(ip net.IP) error { return listenUDP(acceptPort, sendloopDuration) } + +// FilterOutputInterfaceAccept tests that packets are sent via interface +// matching the iptables rule. +type FilterOutputInterfaceAccept struct{} + +// Name implements TestCase.Name. +func (FilterOutputInterfaceAccept) Name() string { + return "FilterOutputInterfaceAccept" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInterfaceAccept) ContainerAction(ip net.IP) error { + ifname, ok := getInterfaceName() + if !ok { + return fmt.Errorf("no interface is present, except loopback") + } + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "ACCEPT"); err != nil { + return err + } + + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInterfaceAccept) LocalAction(ip net.IP) error { + return listenUDP(acceptPort, sendloopDuration) +} + +// FilterOutputInterfaceDrop tests that packets are not sent via interface +// matching the iptables rule. +type FilterOutputInterfaceDrop struct{} + +// Name implements TestCase.Name. +func (FilterOutputInterfaceDrop) Name() string { + return "FilterOutputInterfaceDrop" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInterfaceDrop) ContainerAction(ip net.IP) error { + ifname, ok := getInterfaceName() + if !ok { + return fmt.Errorf("no interface is present, except loopback") + } + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "DROP"); err != nil { + return err + } + + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInterfaceDrop) LocalAction(ip net.IP) error { + if err := listenUDP(acceptPort, sendloopDuration); err == nil { + return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort) + } + + return nil +} + +// FilterOutputInterface tests that packets are sent via interface which is +// not matching the interface name in the iptables rule. +type FilterOutputInterface struct{} + +// Name implements TestCase.Name. +func (FilterOutputInterface) Name() string { + return "FilterOutputInterface" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInterface) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", "lo", "-j", "DROP"); err != nil { + return err + } + + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInterface) LocalAction(ip net.IP) error { + return listenUDP(acceptPort, sendloopDuration) +} + +// FilterOutputInterfaceBeginsWith tests that packets are not sent via an +// interface which begins with the given interface name. +type FilterOutputInterfaceBeginsWith struct{} + +// Name implements TestCase.Name. +func (FilterOutputInterfaceBeginsWith) Name() string { + return "FilterOutputInterfaceBeginsWith" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInterfaceBeginsWith) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", "e+", "-j", "DROP"); err != nil { + return err + } + + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInterfaceBeginsWith) LocalAction(ip net.IP) error { + if err := listenUDP(acceptPort, sendloopDuration); err == nil { + return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort) + } + + return nil +} + +// FilterOutputInterfaceInvertDrop tests that we selectively do not send +// packets via interface not matching the interface name. +type FilterOutputInterfaceInvertDrop struct{} + +// Name implements TestCase.Name. +func (FilterOutputInterfaceInvertDrop) Name() string { + return "FilterOutputInterfaceInvertDrop" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInterfaceInvertDrop) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "DROP"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + if err := listenTCP(acceptPort, sendloopDuration); err == nil { + return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInterfaceInvertDrop) LocalAction(ip net.IP) error { + if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { + return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) + } + + return nil +} + +// FilterOutputInterfaceInvertAccept tests that we can selectively send packets +// not matching the specific outgoing interface. +type FilterOutputInterfaceInvertAccept struct{} + +// Name implements TestCase.Name. +func (FilterOutputInterfaceInvertAccept) Name() string { + return "FilterOutputInterfaceInvertAccept" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInterfaceInvertAccept) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "ACCEPT"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + return listenTCP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInterfaceInvertAccept) LocalAction(ip net.IP) error { + return connectTCP(ip, acceptPort, sendloopDuration) +} diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 63a862d35..84eb75a40 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -167,6 +167,30 @@ func TestFilterOutputOwnerFail(t *testing.T) { singleTest(t, FilterOutputOwnerFail{}) } +func TestFilterOutputInterfaceAccept(t *testing.T) { + singleTest(t, FilterOutputInterfaceAccept{}) +} + +func TestFilterOutputInterfaceDrop(t *testing.T) { + singleTest(t, FilterOutputInterfaceDrop{}) +} + +func TestFilterOutputInterface(t *testing.T) { + singleTest(t, FilterOutputInterface{}) +} + +func TestFilterOutputInterfaceBeginsWith(t *testing.T) { + singleTest(t, FilterOutputInterfaceBeginsWith{}) +} + +func TestFilterOutputInterfaceInvertDrop(t *testing.T) { + singleTest(t, FilterOutputInterfaceInvertDrop{}) +} + +func TestFilterOutputInterfaceInvertAccept(t *testing.T) { + singleTest(t, FilterOutputInterfaceInvertAccept{}) +} + func TestJumpSerialize(t *testing.T) { singleTest(t, FilterInputSerializeJump{}) } diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index 2f988cd18..7146edbb9 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -169,3 +169,18 @@ func localAddrs() ([]string, error) { } return addrStrs, nil } + +// getInterfaceName returns the name of the interface other than loopback. +func getInterfaceName() (string, bool) { + var ifname string + if interfaces, err := net.Interfaces(); err == nil { + for _, intf := range interfaces { + if intf.Name != "lo" { + ifname = intf.Name + break + } + } + } + + return ifname, ifname != "" +} -- cgit v1.2.3 From c55b84e16aeb4481106661e3877c50edbf281762 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 28 May 2020 16:44:15 -0700 Subject: Enable iptables source filtering (-s/--source) --- pkg/sentry/socket/netfilter/netfilter.go | 21 ++++++++--- pkg/tcpip/stack/iptables.go | 47 +----------------------- pkg/tcpip/stack/iptables_types.go | 62 ++++++++++++++++++++++++++++++++ test/iptables/filter_input.go | 60 +++++++++++++++++++++++++++++++ test/iptables/iptables_test.go | 8 +++++ 5 files changed, 147 insertions(+), 51 deletions(-) (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 789bb94c8..47ff48c00 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -64,6 +64,8 @@ const enableLogging = false var emptyFilter = stack.IPHeaderFilter{ Dst: "\x00\x00\x00\x00", DstMask: "\x00\x00\x00\x00", + Src: "\x00\x00\x00\x00", + SrcMask: "\x00\x00\x00\x00", } // nflog logs messages related to the writing and reading of iptables. @@ -214,11 +216,16 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI } copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst) copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask) + copy(entry.IPTEntry.IP.Src[:], rule.Filter.Src) + copy(entry.IPTEntry.IP.SrcMask[:], rule.Filter.SrcMask) copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface) copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) if rule.Filter.DstInvert { entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP } + if rule.Filter.SrcInvert { + entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_SRCIP + } if rule.Filter.OutputInterfaceInvert { entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT } @@ -737,6 +744,9 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize { return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) } + if len(iptip.Src) != header.IPv4AddressSize || len(iptip.SrcMask) != header.IPv4AddressSize { + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) + } n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) if n == -1 { @@ -755,6 +765,9 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { Dst: tcpip.Address(iptip.Dst[:]), DstMask: tcpip.Address(iptip.DstMask[:]), DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0, + Src: tcpip.Address(iptip.Src[:]), + SrcMask: tcpip.Address(iptip.SrcMask[:]), + SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0, OutputInterface: ifname, OutputInterfaceMask: ifnameMask, OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0, @@ -765,15 +778,13 @@ func containsUnsupportedFields(iptip linux.IPTIP) bool { // The following features are supported: // - Protocol // - Dst and DstMask + // - Src and SrcMask // - The inverse destination IP check flag // - OutputInterface, OutputInterfaceMask and its inverse. - var emptyInetAddr = linux.InetAddr{} var emptyInterface = [linux.IFNAMSIZ]byte{} // Disable any supported inverse flags. - inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_VIA_OUT) - return iptip.Src != emptyInetAddr || - iptip.SrcMask != emptyInetAddr || - iptip.InputInterface != emptyInterface || + inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT) + return iptip.InputInterface != emptyInterface || iptip.InputInterfaceMask != emptyInterface || iptip.Flags != 0 || iptip.InverseFlags&^inverseMask != 0 diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 443423b3c..709ede3fa 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -16,7 +16,6 @@ package stack import ( "fmt" - "strings" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -314,7 +313,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // Check whether the packet matches the IP header filter. - if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader), hook, nicName) { + if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -335,47 +334,3 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // All the matchers matched, so run the target. return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) } - -func filterMatch(filter IPHeaderFilter, hdr header.IPv4, hook Hook, nicName string) bool { - // TODO(gvisor.dev/issue/170): Support other fields of the filter. - // Check the transport protocol. - if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() { - return false - } - - // Check the destination IP. - dest := hdr.DestinationAddress() - matches := true - for i := range filter.Dst { - if dest[i]&filter.DstMask[i] != filter.Dst[i] { - matches = false - break - } - } - if matches == filter.DstInvert { - return false - } - - // Check the output interface. - // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING - // hooks after supported. - if hook == Output { - n := len(filter.OutputInterface) - if n == 0 { - return true - } - - // If the interface name ends with '+', any interface which begins - // with the name should be matched. - ifName := filter.OutputInterface - matches = true - if strings.HasSuffix(ifName, "+") { - matches = strings.HasPrefix(nicName, ifName[:n-1]) - } else { - matches = nicName == ifName - } - return filter.OutputInterfaceInvert != matches - } - - return true -} diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index fe06007ae..a3bd3e700 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -15,7 +15,10 @@ package stack import ( + "strings" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" ) // A Hook specifies one of the hooks built into the network stack. @@ -159,6 +162,16 @@ type IPHeaderFilter struct { // comparison. DstInvert bool + // Src matches the source IP address. + Src tcpip.Address + + // SrcMask masks bits of the source IP address when comparing with Src. + SrcMask tcpip.Address + + // SrcInvert inverts the meaning of the source IP check, i.e. when true the + // filter will match packets that fail the source comparison. + SrcInvert bool + // OutputInterface matches the name of the outgoing interface for the // packet. OutputInterface string @@ -173,6 +186,55 @@ type IPHeaderFilter struct { OutputInterfaceInvert bool } +// match returns whether hdr matches the filter. +func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool { + // TODO(gvisor.dev/issue/170): Support other fields of the filter. + // Check the transport protocol. + if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() { + return false + } + + // Check the source and destination IPs. + if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) { + return false + } + + // Check the output interface. + // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING + // hooks after supported. + if hook == Output { + n := len(fl.OutputInterface) + if n == 0 { + return true + } + + // If the interface name ends with '+', any interface which begins + // with the name should be matched. + ifName := fl.OutputInterface + matches := true + if strings.HasSuffix(ifName, "+") { + matches = strings.HasPrefix(nicName, ifName[:n-1]) + } else { + matches = nicName == ifName + } + return fl.OutputInterfaceInvert != matches + } + + return true +} + +// filterAddress returns whether addr matches the filter. +func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool { + matches := true + for i := range filterAddr { + if addr[i]&mask[i] != filterAddr[i] { + matches = false + break + } + } + return matches != invert +} + // A Matcher is the interface for matching packets. type Matcher interface { // Name returns the name of the Matcher. diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go index 41e0cfa8d..14e385f5a 100644 --- a/test/iptables/filter_input.go +++ b/test/iptables/filter_input.go @@ -49,6 +49,8 @@ func init() { RegisterTestCase(FilterInputJumpTwice{}) RegisterTestCase(FilterInputDestination{}) RegisterTestCase(FilterInputInvertDestination{}) + RegisterTestCase(FilterInputSource{}) + RegisterTestCase(FilterInputInvertSource{}) } // FilterInputDropUDP tests that we can drop UDP traffic. @@ -667,3 +669,61 @@ func (FilterInputInvertDestination) ContainerAction(ip net.IP) error { func (FilterInputInvertDestination) LocalAction(ip net.IP) error { return sendUDPLoop(ip, acceptPort, sendloopDuration) } + +// FilterInputSource verifies that we can filter packets via `-d +// `. +type FilterInputSource struct{} + +// Name implements TestCase.Name. +func (FilterInputSource) Name() string { + return "FilterInputSource" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputSource) ContainerAction(ip net.IP) error { + // Make INPUT's default action DROP, then ACCEPT all packets from this + // machine. + rules := [][]string{ + {"-P", "INPUT", "DROP"}, + {"-A", "INPUT", "-s", fmt.Sprintf("%v", ip), "-j", "ACCEPT"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputSource) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// FilterInputInvertSource verifies that we can filter packets via `! -d +// `. +type FilterInputInvertSource struct{} + +// Name implements TestCase.Name. +func (FilterInputInvertSource) Name() string { + return "FilterInputInvertSource" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputInvertSource) ContainerAction(ip net.IP) error { + // Make INPUT's default action DROP, then ACCEPT all packets not bound + // for 127.0.0.1. + rules := [][]string{ + {"-P", "INPUT", "DROP"}, + {"-A", "INPUT", "!", "-s", localIP, "-j", "ACCEPT"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputInvertSource) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 4fd2cb46a..172ad9e16 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -302,3 +302,11 @@ func TestNATPreRedirectInvert(t *testing.T) { func TestNATRedirectRequiresProtocol(t *testing.T) { singleTest(t, NATRedirectRequiresProtocol{}) } + +func TestInputSource(t *testing.T) { + singleTest(t, FilterInputSource{}) +} + +func TestInputInvertSource(t *testing.T) { + singleTest(t, FilterInputInvertSource{}) +} -- cgit v1.2.3 From d3a8bffe04595910714ec67231585bc33dab2b5b Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Wed, 3 Jun 2020 14:57:57 -0700 Subject: Pass PacketBuffer as pointer. Historically we've been passing PacketBuffer by shallow copying through out the stack. Right now, this is only correct as the caller would not use PacketBuffer after passing into the next layer in netstack. With new buffer management effort in gVisor/netstack, PacketBuffer will own a Buffer (to be added). Internally, both PacketBuffer and Buffer may have pointers and shallow copying shouldn't be used. Updates #2404. PiperOrigin-RevId: 314610879 --- pkg/sentry/socket/netfilter/owner_matcher.go | 2 +- pkg/sentry/socket/netfilter/tcp_matcher.go | 2 +- pkg/sentry/socket/netfilter/udp_matcher.go | 2 +- pkg/tcpip/link/channel/channel.go | 8 ++--- pkg/tcpip/link/fdbased/endpoint.go | 4 +-- pkg/tcpip/link/fdbased/endpoint_test.go | 10 +++---- pkg/tcpip/link/fdbased/mmap.go | 2 +- pkg/tcpip/link/fdbased/packet_dispatchers.go | 4 +-- pkg/tcpip/link/loopback/loopback.go | 6 ++-- pkg/tcpip/link/muxed/injectable.go | 4 +-- pkg/tcpip/link/muxed/injectable_test.go | 4 +-- pkg/tcpip/link/qdisc/fifo/endpoint.go | 6 ++-- pkg/tcpip/link/sharedmem/sharedmem.go | 4 +-- pkg/tcpip/link/sharedmem/sharedmem_test.go | 26 ++++++++--------- pkg/tcpip/link/sniffer/sniffer.go | 8 ++--- pkg/tcpip/link/tun/device.go | 2 +- pkg/tcpip/link/waitable/waitable.go | 4 +-- pkg/tcpip/link/waitable/waitable_test.go | 16 +++++----- pkg/tcpip/network/arp/arp.go | 10 +++---- pkg/tcpip/network/arp/arp_test.go | 2 +- pkg/tcpip/network/ip_test.go | 30 ++++++++++++------- pkg/tcpip/network/ipv4/icmp.go | 12 +++++--- pkg/tcpip/network/ipv4/ipv4.go | 34 +++++++++++----------- pkg/tcpip/network/ipv4/ipv4_test.go | 28 +++++++++++------- pkg/tcpip/network/ipv6/icmp.go | 10 +++---- pkg/tcpip/network/ipv6/icmp_test.go | 14 ++++----- pkg/tcpip/network/ipv6/ipv6.go | 8 ++--- pkg/tcpip/network/ipv6/ipv6_test.go | 8 ++--- pkg/tcpip/network/ipv6/ndp_test.go | 10 +++---- pkg/tcpip/stack/conntrack.go | 4 +-- pkg/tcpip/stack/forwarder.go | 4 +-- pkg/tcpip/stack/forwarder_test.go | 34 +++++++++++----------- pkg/tcpip/stack/iptables.go | 2 +- pkg/tcpip/stack/iptables_types.go | 2 +- pkg/tcpip/stack/ndp.go | 4 +-- pkg/tcpip/stack/ndp_test.go | 14 ++++----- pkg/tcpip/stack/nic.go | 12 ++++---- pkg/tcpip/stack/nic_test.go | 2 +- pkg/tcpip/stack/packet_buffer.go | 33 +++++++++++++++++++-- pkg/tcpip/stack/registration.go | 26 ++++++++--------- pkg/tcpip/stack/route.go | 4 +-- pkg/tcpip/stack/stack.go | 4 +-- pkg/tcpip/stack/stack_test.go | 26 ++++++++--------- pkg/tcpip/stack/transport_demuxer.go | 14 ++++----- pkg/tcpip/stack/transport_demuxer_test.go | 4 +-- pkg/tcpip/stack/transport_test.go | 22 +++++++------- pkg/tcpip/transport/icmp/endpoint.go | 8 ++--- pkg/tcpip/transport/icmp/protocol.go | 2 +- pkg/tcpip/transport/packet/endpoint.go | 2 +- pkg/tcpip/transport/raw/endpoint.go | 6 ++-- pkg/tcpip/transport/tcp/connect.go | 4 +-- pkg/tcpip/transport/tcp/dispatcher.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 4 +-- pkg/tcpip/transport/tcp/forwarder.go | 2 +- pkg/tcpip/transport/tcp/protocol.go | 4 +-- pkg/tcpip/transport/tcp/segment.go | 2 +- pkg/tcpip/transport/tcp/testing/context/context.go | 10 +++---- pkg/tcpip/transport/udp/endpoint.go | 10 +++++-- pkg/tcpip/transport/udp/forwarder.go | 4 +-- pkg/tcpip/transport/udp/protocol.go | 6 ++-- pkg/tcpip/transport/udp/udp_test.go | 4 +-- 61 files changed, 306 insertions(+), 255 deletions(-) (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index 3863293c7..1b4e0ad79 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -111,7 +111,7 @@ func (*OwnerMatcher) Name() string { } // Match implements Matcher.Match. -func (om *OwnerMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { +func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { // Support only for OUTPUT chain. // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also. if hook != stack.Output { diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 57a1e1c12..ebabdf334 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -96,7 +96,7 @@ func (*TCPMatcher) Name() string { } // Match implements Matcher.Match. -func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { +func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { netHeader := header.IPv4(pkt.NetworkHeader) if netHeader.TransportProtocol() != header.TCPProtocolNumber { diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index cfa9e621d..98b9943f8 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -93,7 +93,7 @@ func (*UDPMatcher) Name() string { } // Match implements Matcher.Match. -func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { +func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { netHeader := header.IPv4(pkt.NetworkHeader) // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 5eb78b398..20b183da0 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -181,12 +181,12 @@ func (e *Endpoint) NumQueued() int { } // InjectInbound injects an inbound packet. -func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt stack.PacketBuffer) { +func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *stack.PacketBuffer) { e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) } @@ -229,13 +229,13 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // Clone r then release its resource so we only get the relevant fields from // stack.Route without holding a reference to a NIC's endpoint. route := r.Clone() route.Release() p := PacketInfo{ - Pkt: &pkt, + Pkt: pkt, Proto: protocol, GSO: gso, Route: route, diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 5ee508d48..f34082e1a 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -387,7 +387,7 @@ const ( // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if e.hdrSize > 0 { // Add ethernet header if needed. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) @@ -641,7 +641,7 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { } // InjectInbound injects an inbound packet. -func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt) } diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 6f41a71a8..eaee7e5d7 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -45,7 +45,7 @@ const ( type packetInfo struct { raddr tcpip.LinkAddress proto tcpip.NetworkProtocolNumber - contents stack.PacketBuffer + contents *stack.PacketBuffer } type context struct { @@ -103,7 +103,7 @@ func (c *context) cleanup() { } } -func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { c.ch <- packetInfo{remote, protocol, pkt} } @@ -179,7 +179,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(r, gso, proto, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), Hash: hash, @@ -295,7 +295,7 @@ func TestPreserveSrcAddress(t *testing.T) { // WritePacket panics given a prependable with anything less than // the minimum size of the ethernet header. hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(r, nil /* gso */, proto, &stack.PacketBuffer{ Header: hdr, Data: buffer.VectorisedView{}, }); err != nil { @@ -358,7 +358,7 @@ func TestDeliverPacket(t *testing.T) { want := packetInfo{ raddr: raddr, proto: proto, - contents: stack.PacketBuffer{ + contents: &stack.PacketBuffer{ Data: buffer.View(b).ToVectorisedView(), LinkHeader: buffer.View(hdr), }, diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index ca4229ed6..2dfd29aa9 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -191,7 +191,7 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(remote, local, p, stack.PacketBuffer{ + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, &stack.PacketBuffer{ Data: buffer.View(pkt).ToVectorisedView(), LinkHeader: buffer.View(eth), }) diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 26c96a655..f04738cfb 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -139,7 +139,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(n, BufConfig) - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), LinkHeader: buffer.View(eth), } @@ -296,7 +296,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(k, int(n), BufConfig) - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), LinkHeader: buffer.View(eth), } diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 20d9e95f6..568c6874f 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -76,7 +76,7 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) @@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw // Because we're immediately turning around and writing the packet back // to the rx path, we intentionally don't preserve the remote and local // link addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, stack.PacketBuffer{ + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -106,7 +106,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { } linkHeader := header.Ethernet(hdr) vv.TrimFront(len(linkHeader)) - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), &stack.PacketBuffer{ Data: vv, LinkHeader: buffer.View(linkHeader), }) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index f0769830a..c69d6b7e9 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -80,7 +80,7 @@ func (m *InjectableEndpoint) IsAttached() bool { } // InjectInbound implements stack.InjectableLinkEndpoint. -func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { m.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt) } @@ -98,7 +98,7 @@ func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts s // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { return endpoint.WritePacket(r, gso, protocol, pkt) } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 87c734c1f..0744f66d6 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -50,7 +50,7 @@ func TestInjectableEndpointDispatch(t *testing.T) { hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{ + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), }) @@ -70,7 +70,7 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { hdr := buffer.NewPrependable(1) hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{ + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buffer.NewView(0).ToVectorisedView(), }) diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index ec5c5048a..b5dfb7850 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -102,7 +102,7 @@ func (q *queueDispatcher) dispatchLoop() { } // DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. -func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } @@ -146,7 +146,7 @@ func (e *endpoint) GSOMaxSize() uint32 { } // WritePacket implements stack.LinkEndpoint.WritePacket. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // WritePacket caller's do not set the following fields in PacketBuffer // so we populate them here. newRoute := r.Clone() @@ -154,7 +154,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne pkt.GSOOptions = gso pkt.NetworkProtocolNumber = protocol d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] - if !d.q.enqueue(&pkt) { + if !d.q.enqueue(pkt) { return tcpip.ErrNoBufferSpace } d.newPacketWaker.Assert() diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index f5dec0a7f..0374a2441 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -185,7 +185,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // Add the ethernet header here. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) pkt.LinkHeader = buffer.View(eth) @@ -275,7 +275,7 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { // Send packet up the stack. eth := header.Ethernet(b[:header.EthernetMinimumSize]) - d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), stack.PacketBuffer{ + d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), &stack.PacketBuffer{ Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), LinkHeader: buffer.View(eth), }) diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index f3fc62607..28a2e88ba 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -131,7 +131,7 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress return c } -func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { c.mu.Lock() c.packets = append(c.packets, packetInfo{ addr: remoteLinkAddr, @@ -273,7 +273,7 @@ func TestSimpleSend(t *testing.T) { randomFill(buf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -345,7 +345,7 @@ func TestPreserveSrcAddressInSend(t *testing.T) { hdr := buffer.NewPrependable(header.EthernetMinimumSize) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ Header: hdr, }); err != nil { t.Fatalf("WritePacket failed: %v", err) @@ -401,7 +401,7 @@ func TestFillTxQueue(t *testing.T) { for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -419,7 +419,7 @@ func TestFillTxQueue(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != want { @@ -447,7 +447,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -470,7 +470,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -488,7 +488,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != want { @@ -514,7 +514,7 @@ func TestFillTxMemory(t *testing.T) { ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -533,7 +533,7 @@ func TestFillTxMemory(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }) @@ -561,7 +561,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -577,7 +577,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: uu, }); err != want { @@ -588,7 +588,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write the one-buffer packet again. It must succeed. { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index b060d4627..ae3186314 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -120,8 +120,8 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) ( // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. -func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { - e.dumpPacket("recv", nil, protocol, &pkt) +func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dumpPacket("recv", nil, protocol, pkt) e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } @@ -208,8 +208,8 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw // WritePacket implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { - e.dumpPacket("send", gso, protocol, &pkt) +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.dumpPacket("send", gso, protocol, pkt) return e.lower.WritePacket(r, gso, protocol, pkt) } diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 617446ea2..6bc9033d0 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -213,7 +213,7 @@ func (d *Device) Write(data []byte) (int64, error) { remote = tcpip.LinkAddress(zeroMAC[:]) } - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Data: buffer.View(data).ToVectorisedView(), } if ethHdr != nil { diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index f5a05929f..949b3f2b2 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -50,7 +50,7 @@ func New(lower stack.LinkEndpoint) *Endpoint { // It is called by the link-layer endpoint being wrapped when a packet arrives, // and only forwards to the actual dispatcher if Wait or WaitDispatch haven't // been called. -func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if !e.dispatchGate.Enter() { return } @@ -99,7 +99,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if !e.writeGate.Enter() { return nil } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 0a9b99f18..63bf40562 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -35,7 +35,7 @@ type countedEndpoint struct { dispatcher stack.NetworkDispatcher } -func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.dispatchCount++ } @@ -65,7 +65,7 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { e.writeCount++ return nil } @@ -89,21 +89,21 @@ func TestWaitWrite(t *testing.T) { wep := New(ep) // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } @@ -120,21 +120,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 9d0797af7..ea1acba83 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -80,7 +80,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -94,11 +94,11 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { v, ok := pkt.Data.PullUp(header.ARPSize) if !ok { return @@ -122,7 +122,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ + e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ Header: hdr, }) fallthrough // also fill the cache from requests @@ -177,7 +177,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ Header: hdr, }) } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 1646d9cde..66e67429c 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -103,7 +103,7 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.InjectInbound(arp.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{ Data: v.ToVectorisedView(), }) } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4c20301c6..d9b62f2db 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -96,7 +96,7 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt stack.PacketBuffer) { +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) { t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) t.dataCalls++ } @@ -104,7 +104,7 @@ func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.Trans // DeliverTransportControlPacket is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { t.checkValues(trans, pkt.Data, remote, local) if typ != t.typ { t.t.Errorf("typ = %v, want %v", typ, t.typ) @@ -150,7 +150,7 @@ func (*testObject) Wait() {} // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address @@ -246,7 +246,11 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + Protocol: 123, + TTL: 123, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -289,7 +293,7 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: view.ToVectorisedView(), }) if o.dataCalls != 1 { @@ -379,7 +383,7 @@ func TestIPv4ReceiveControl(t *testing.T) { o.extra = c.expectedExtra vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: vv, }) if want := c.expectedCount; o.controlCalls != want { @@ -444,7 +448,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: frag1.ToVectorisedView(), }) if o.dataCalls != 0 { @@ -452,7 +456,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send second segment. - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: frag2.ToVectorisedView(), }) if o.dataCalls != 1 { @@ -487,7 +491,11 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + Protocol: 123, + TTL: 123, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -530,7 +538,7 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: view.ToVectorisedView(), }) if o.dataCalls != 1 { @@ -644,7 +652,7 @@ func TestIPv6ReceiveControl(t *testing.T) { // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: view[:len(view)-c.trunc].ToVectorisedView(), }) if want := c.expectedCount; o.controlCalls != want { diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 4cbefe5ab..d1c3ae835 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -24,7 +24,7 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { return @@ -56,7 +56,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) @@ -88,7 +88,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, stack.PacketBuffer{ + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, &stack.PacketBuffer{ Data: pkt.Data.Clone(nil), NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), }) @@ -102,7 +102,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: header.ICMPv4ProtocolNumber, + TTL: r.DefaultTTL(), + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: vv, TransportHeader: buffer.View(pkt), diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 64046cbbf..9cd7592f4 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -129,7 +129,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { // packet's stated length matches the length of the header+payload. mtu // includes the IP header and options. This does not support the DontFragment // IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. ip := header.IPv4(pkt.Header.View()) flags := ip.Flags() @@ -169,7 +169,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, if i > 0 { newPayload := pkt.Data.Clone(nil) newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ Header: pkt.Header, Data: newPayload, NetworkHeader: buffer.View(h), @@ -188,7 +188,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, newPayload := pkt.Data.Clone(nil) newPayloadLength := outerMTU - pkt.Header.UsedLength() newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ Header: pkt.Header, Data: newPayload, NetworkHeader: buffer.View(h), @@ -202,7 +202,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, startOfHdr := pkt.Header startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ Header: startOfHdr, Data: emptyVV, NetworkHeader: buffer.View(h), @@ -245,7 +245,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) @@ -253,7 +253,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Output, &pkt, gso, r, "", nicName); !ok { + if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. return nil } @@ -271,9 +271,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) - packet := stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} - ep.HandlePacket(&route, packet) + ep.HandlePacket(&route, &stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) return nil } } @@ -286,7 +286,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, stack.PacketBuffer{ + e.HandlePacket(&loopedR, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -351,14 +351,14 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) - packet := stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} - ep.HandlePacket(&route, packet) + ep.HandlePacket(&route, &stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) n++ continue } } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) return n, err } @@ -370,7 +370,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) @@ -426,7 +426,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { r.Stats().IP.MalformedPacketsReceived.Increment() @@ -447,7 +447,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Input, &pkt, nil, nil, "", ""); !ok { + if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. return } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 36035c820..c208ebd99 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -114,7 +114,7 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketInfo stack.PacketBuffer, mtu uint32) { +func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) @@ -174,7 +174,7 @@ func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketIn type errorChannel struct { *channel.Endpoint - Ch chan stack.PacketBuffer + Ch chan *stack.PacketBuffer packetCollectorErrors []*tcpip.Error } @@ -184,7 +184,7 @@ type errorChannel struct { func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { return &errorChannel{ Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan stack.PacketBuffer, size), + Ch: make(chan *stack.PacketBuffer, size), packetCollectorErrors: packetCollectorErrors, } } @@ -203,7 +203,7 @@ func (e *errorChannel) Drain() int { } // WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { select { case e.Ch <- pkt: default: @@ -282,13 +282,17 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := stack.PacketBuffer{ + source := &stack.PacketBuffer{ Header: hdr, // Save the source payload because WritePacket will modify it. Data: payload.Clone(nil), } c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: 42, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -296,7 +300,7 @@ func TestFragmentation(t *testing.T) { t.Errorf("err got %v, want %v", err, nil) } - var results []stack.PacketBuffer + var results []*stack.PacketBuffer L: for { select { @@ -338,7 +342,11 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: 42, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -460,7 +468,7 @@ func TestInvalidFragments(t *testing.T) { s.CreateNIC(nicID, sniffer.New(ep)) for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.PacketBuffer{ + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), }) } @@ -698,7 +706,7 @@ func TestReceiveFragments(t *testing.T) { vv := hdr.View().ToVectorisedView() vv.AppendView(frag.payload) - e.InjectInbound(header.IPv4ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ Data: vv, }) } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index bdf3a0d25..b62fb1de6 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -27,7 +27,7 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) if !ok { return @@ -70,7 +70,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { +func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt *stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived @@ -288,7 +288,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, }); err != nil { sent.Dropped.Increment() @@ -390,7 +390,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: pkt.Data, }); err != nil { @@ -532,7 +532,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ Header: hdr, }) } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d412ff688..a720f626f 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -57,7 +57,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, stack.PacketBuffer) *tcpip.Error { +func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { return nil } @@ -67,7 +67,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, stack.PacketBuffer) { +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) { } type stubLinkAddressCache struct { @@ -189,7 +189,7 @@ func TestICMPCounts(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -328,7 +328,7 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), stack.PacketBuffer{ + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), &stack.PacketBuffer{ Data: vv, }) } @@ -563,7 +563,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -740,7 +740,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -918,7 +918,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), }) } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index daf1fcbc6..0d94ad122 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -116,7 +116,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) @@ -128,7 +128,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, stack.PacketBuffer{ + e.HandlePacket(&loopedR, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -163,14 +163,14 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // TODO(b/146666412): Support IPv6 header-included packets. return tcpip.ErrNotSupported } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) if !ok { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 841a0cb7a..213ff64f2 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -65,7 +65,7 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -123,7 +123,7 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -637,7 +637,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { DstAddr: addr2, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -1238,7 +1238,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { vv := hdr.View().ToVectorisedView() vv.Append(f.data) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: vv, }) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 12b70f7e9..3c141b91b 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -136,7 +136,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -380,7 +380,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.PacketBuffer{ + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -497,7 +497,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -568,7 +568,7 @@ func TestNDPValidation(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(r, stack.PacketBuffer{ + ep.HandlePacket(r, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -884,7 +884,7 @@ func TestRouterAdvertValidation(t *testing.T) { t.Fatalf("got rxRA = %d, want = 0", got) } - e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 7d1ede1f2..d4053be08 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -186,7 +186,7 @@ func parseHeaders(pkt *PacketBuffer) { } // packetToTuple converts packet to a tuple in original direction. -func packetToTuple(pkt PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { +func packetToTuple(pkt *PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { var tuple connTrackTuple netHeader := header.IPv4(pkt.NetworkHeader) @@ -265,7 +265,7 @@ func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, creat } var dir ctDirection - tuple, err := packetToTuple(*pkt, hook) + tuple, err := packetToTuple(pkt, hook) if err != nil { return nil, dir } diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go index 6b64cd37f..3eff141e6 100644 --- a/pkg/tcpip/stack/forwarder.go +++ b/pkg/tcpip/stack/forwarder.go @@ -32,7 +32,7 @@ type pendingPacket struct { nic *NIC route *Route proto tcpip.NetworkProtocolNumber - pkt PacketBuffer + pkt *PacketBuffer } type forwardQueue struct { @@ -50,7 +50,7 @@ func newForwardQueue() *forwardQueue { return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)} } -func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { shouldWait := false f.Lock() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 344d60baa..63537aaad 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -68,7 +68,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { return &f.id } -func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { +func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { // Consume the network header. b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) if !ok { @@ -96,7 +96,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu return f.proto.Number() } -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. b := pkt.Header.Prepend(fwdTestNetHeaderLen) @@ -112,7 +112,7 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf panic("not implemented") } -func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error { +func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -190,7 +190,7 @@ func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumb type fwdTestPacketInfo struct { RemoteLinkAddress tcpip.LinkAddress LocalLinkAddress tcpip.LinkAddress - Pkt PacketBuffer + Pkt *PacketBuffer } type fwdTestLinkEndpoint struct { @@ -203,12 +203,12 @@ type fwdTestLinkEndpoint struct { } // InjectInbound injects an inbound packet. -func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt PacketBuffer) { +func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) { e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) } @@ -251,7 +251,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -270,7 +270,7 @@ func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.Netw func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.WritePacket(r, gso, protocol, *pkt) + e.WritePacket(r, gso, protocol, pkt) n++ } @@ -280,7 +280,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := fwdTestPacketInfo{ - Pkt: PacketBuffer{Data: vv}, + Pkt: &PacketBuffer{Data: vv}, } select { @@ -362,7 +362,7 @@ func TestForwardingWithStaticResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -399,7 +399,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -430,7 +430,7 @@ func TestForwardingWithNoResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -460,7 +460,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // not be forwarded. buf := buffer.NewView(30) buf[0] = 4 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -468,7 +468,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // forwarded to NIC 2. buf = buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -510,7 +510,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { for i := 0; i < 2; i++ { buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -557,7 +557,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { buf[0] = 3 // Set the packet sequence number. binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -610,7 +610,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // maxPendingResolutions + 7). buf := buffer.NewView(30) buf[0] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 709ede3fa..d989dbe91 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -321,7 +321,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // Go through each rule matcher. If they all match, run // the rule target. for _, matcher := range rule.Matchers { - matches, hotdrop := matcher.Match(hook, *pkt, "") + matches, hotdrop := matcher.Match(hook, pkt, "") if hotdrop { return RuleDrop, 0 } diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index a3bd3e700..af72b9c46 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -245,7 +245,7 @@ type Matcher interface { // used for suspicious packets. // // Precondition: packet.NetworkHeader is set. - Match(hook Hook, packet PacketBuffer, interfaceName string) (matches bool, hotdrop bool) + Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) } // A Target is the interface for taking an action for a packet. diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 526c7d6ff..ae7a8f740 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -750,7 +750,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, PacketBuffer{Header: hdr}, + }, &PacketBuffer{Header: hdr}, ); err != nil { sent.Dropped.Increment() return err @@ -1881,7 +1881,7 @@ func (ndp *ndpState) startSolicitingRouters() { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, PacketBuffer{Header: hdr}, + }, &PacketBuffer{Header: hdr}, ); err != nil { sent.Dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index b3d174cdd..58f1ebf60 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -613,7 +613,7 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate multiple nodes owning or // attempting to own the same address. hdr := test.makeBuf(addr1) - e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -935,7 +935,7 @@ func TestSetNDPConfigurations(t *testing.T) { // raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options // and DHCPv6 configurations specified. -func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) stack.PacketBuffer { +func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) @@ -970,14 +970,14 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo DstAddr: header.IPv6AllNodesMulticastAddress, }) - return stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} + return &stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} } // raBufWithOpts returns a valid NDP Router Advertisement with options. // // Note, raBufWithOpts does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) stack.PacketBuffer { +func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) } @@ -986,7 +986,7 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ // // Note, raBufWithDHCPv6 does not populate any of the RA fields other than the // DHCPv6 related ones. -func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) stack.PacketBuffer { +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer { return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) } @@ -994,7 +994,7 @@ func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bo // // Note, raBuf does not populate any of the RA fields other than the // Router Lifetime. -func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer { +func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer { return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) } @@ -1003,7 +1003,7 @@ func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer { // // Note, raBufWithPI does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) stack.PacketBuffer { +func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) *stack.PacketBuffer { flags := uint8(0) if onLink { // The OnLink flag is the 7th bit in the flags byte. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 05646e5e2..ec8e3cb85 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1153,7 +1153,7 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { return joins != 0 } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt PacketBuffer) { +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr @@ -1167,7 +1167,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { n.mu.RLock() enabled := n.mu.enabled // If the NIC is not yet enabled, don't receive any packets. @@ -1233,7 +1233,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) - if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address, ""); !ok { + if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok { // iptables is telling us to drop the packet. return } @@ -1298,7 +1298,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } -func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { pkt.Header = buffer.NewPrependable(linkHeaderLen) @@ -1318,7 +1318,7 @@ func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -1365,7 +1365,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // DeliverTransportControlPacket delivers control packets to the appropriate // transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) { +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index b01b3f476..fea46158c 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -44,7 +44,7 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { t.FailNow() } - nic.DeliverNetworkPacket("", "", 0, PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) + nic.DeliverNetworkPacket("", "", 0, &PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 926df4d7b..1b5da6017 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -24,6 +24,8 @@ import ( // multiple endpoints. Clone() should be called in such cases so that // modifications to the Data field do not affect other copies. type PacketBuffer struct { + _ noCopy + // PacketBufferEntry is used to build an intrusive list of // PacketBuffers. PacketBufferEntry @@ -82,7 +84,32 @@ type PacketBuffer struct { // VectorisedView but does not deep copy the underlying bytes. // // Clone also does not deep copy any of its other fields. -func (pk PacketBuffer) Clone() PacketBuffer { - pk.Data = pk.Data.Clone(nil) - return pk +// +// FIXME(b/153685824): Data gets copied but not other header references. +func (pk *PacketBuffer) Clone() *PacketBuffer { + return &PacketBuffer{ + PacketBufferEntry: pk.PacketBufferEntry, + Data: pk.Data.Clone(nil), + Header: pk.Header, + LinkHeader: pk.LinkHeader, + NetworkHeader: pk.NetworkHeader, + TransportHeader: pk.TransportHeader, + Hash: pk.Hash, + Owner: pk.Owner, + EgressRoute: pk.EgressRoute, + GSOOptions: pk.GSOOptions, + NetworkProtocolNumber: pk.NetworkProtocolNumber, + NatDone: pk.NatDone, + } } + +// noCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://golang.org/issues/8005#issuecomment-190753527 +// for details. +type noCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. +func (*noCopy) Lock() {} +func (*noCopy) Unlock() {} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index db89234e8..94f177841 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -67,12 +67,12 @@ type TransportEndpoint interface { // this transport endpoint. It sets pkt.TransportHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) + HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. // HandleControlPacket takes ownership of pkt. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) // Abort initiates an expedited endpoint teardown. It puts the endpoint // in a closed state and frees all resources associated with it. This @@ -100,7 +100,7 @@ type RawTransportEndpoint interface { // layer up. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt PacketBuffer) + HandlePacket(r *Route, pkt *PacketBuffer) } // PacketEndpoint is the interface that needs to be implemented by packet @@ -118,7 +118,7 @@ type PacketEndpoint interface { // should construct its own ethernet header for applications. // // HandlePacket takes ownership of pkt. - HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt PacketBuffer) + HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // TransportProtocol is the interface that needs to be implemented by transport @@ -150,7 +150,7 @@ type TransportProtocol interface { // stats purposes only). // // HandleUnknownDestinationPacket takes ownership of pkt. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt PacketBuffer) bool + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -180,7 +180,7 @@ type TransportDispatcher interface { // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // // DeliverTransportPacket takes ownership of pkt. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. @@ -189,7 +189,7 @@ type TransportDispatcher interface { // DeliverTransportControlPacket. // // DeliverTransportControlPacket takes ownership of pkt. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -242,7 +242,7 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. It takes ownership of pkt. pkt.TransportHeader must have already // been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. It takes ownership of pkts and @@ -251,7 +251,7 @@ type NetworkEndpoint interface { // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. It takes ownership of pkt. - WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID @@ -266,7 +266,7 @@ type NetworkEndpoint interface { // this network endpoint. It sets pkt.NetworkHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt PacketBuffer) + HandlePacket(r *Route, pkt *PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -327,7 +327,7 @@ type NetworkDispatcher interface { // packets sent via loopback), and won't have the field set. // // DeliverNetworkPacket takes ownership of pkt. - DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) + DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -389,7 +389,7 @@ type LinkEndpoint interface { // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets with the given protocol through the // given route. pkts must not be zero length. It takes ownership of pkts and @@ -431,7 +431,7 @@ type InjectableLinkEndpoint interface { LinkEndpoint // InjectInbound injects an inbound packet. - InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) + InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) // InjectOutbound writes a fully formed outbound packet directly to the // link. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 3d0e5cc6e..f5b6ca0b9 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -153,7 +153,7 @@ func (r *Route) IsResolutionRequired() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } @@ -199,7 +199,7 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (r *Route) WriteHeaderIncludedPacket(pkt PacketBuffer) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 0ab4c3e19..8af06cb9a 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -52,7 +52,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, pkt PacketBuffer) bool + defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -778,7 +778,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, PacketBuffer) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 1a2cf007c..f6ddc3ced 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -90,7 +90,7 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { return &f.id } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ @@ -132,7 +132,7 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe return f.proto.Number() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -147,7 +147,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) - f.HandlePacket(r, stack.PacketBuffer{ + f.HandlePacket(r, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) } @@ -163,7 +163,7 @@ func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -293,7 +293,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 0 { @@ -305,7 +305,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -317,7 +317,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -328,7 +328,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber-1, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -340,7 +340,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -362,7 +362,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }) @@ -420,7 +420,7 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if got := fakeNet.PacketCount(localAddrByte); got != want { @@ -2263,7 +2263,7 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { @@ -2345,7 +2345,7 @@ func TestNICForwarding(t *testing.T) { // Send a packet to dstAddr. buf := buffer.NewView(30) buf[0] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 9a33ed375..e09866405 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -152,7 +152,7 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) { +func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) { epsByNIC.mu.RLock() mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] @@ -183,7 +183,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) { +func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -251,7 +251,7 @@ type transportDemuxer struct { // the dispatcher to delivery packets to the QueuePacket method instead of // calling HandlePacket directly on the endpoint. type queuedTransportProtocol interface { - QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt PacketBuffer) + QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { @@ -379,7 +379,7 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt PacketBuffer) { +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] // HandlePacket takes ownership of pkt, so each endpoint needs @@ -470,7 +470,7 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if // the packet no longer needs to be handled. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -520,7 +520,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) bool { +func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -544,7 +544,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 2474a7db3..67d778137 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -127,7 +127,7 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), @@ -165,7 +165,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a611e44ab..cb350ead3 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -88,7 +88,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: buffer.View(v).ToVectorisedView(), }); err != nil { @@ -215,7 +215,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { @@ -232,7 +232,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE } } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -289,7 +289,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { return true } @@ -369,7 +369,7 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 0 { @@ -380,7 +380,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 0 { @@ -391,7 +391,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 1 { @@ -446,7 +446,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 0 { @@ -457,7 +457,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 0 { @@ -468,7 +468,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 1 { @@ -623,7 +623,7 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep2.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: req.ToVectorisedView(), }) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index b1d820372..29ff68df3 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -450,7 +450,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: data.ToVectorisedView(), TransportHeader: buffer.View(icmpv4), @@ -481,7 +481,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: dataVV, TransportHeader: buffer.View(icmpv6), @@ -743,7 +743,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: @@ -805,7 +805,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { } // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 3c47692b2..2ec6749c7 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { return true } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 23158173d..bab2d63ae 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -298,7 +298,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.PacketEndpoint.HandlePacket. -func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { ep.rcvMu.Lock() // Drop the packet if our buffer is currently full. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index eee754a5a..25a17940d 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -348,7 +348,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, switch e.NetProto { case header.IPv4ProtocolNumber: if !e.associated { - if err := route.WriteHeaderIncludedPacket(stack.PacketBuffer{ + if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ Data: buffer.View(payloadBytes).ToVectorisedView(), }); err != nil { return 0, nil, err @@ -357,7 +357,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: buffer.View(payloadBytes).ToVectorisedView(), Owner: e.owner, @@ -584,7 +584,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full. diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index e4a06c9e1..7da93dcc4 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -833,13 +833,13 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac return sendTCPBatch(r, tf, data, gso, owner) } - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), Data: data, Hash: tf.txHash, Owner: owner, } - buildTCPHdr(r, tf, &pkt, gso) + buildTCPHdr(r, tf, pkt, gso) if tf.ttl == 0 { tf.ttl = r.DefaultTTL() diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 6062ca916..047704c80 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -186,7 +186,7 @@ func (d *dispatcher) wait() { } } -func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { ep := stackEP.(*endpoint) s := newSegment(r, id, pkt) if !s.parse() { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index b5ba972f1..d048ef90c 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2462,7 +2462,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { }, nil } -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // TCP HandlePacket is not required anymore as inbound packets first // land at the Dispatcher which then can either delivery using the // worker go routine or directly do the invoke the tcp processing inline @@ -2481,7 +2481,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool { } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 704d01c64..070b634b4 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -61,7 +61,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { s := newSegment(r, id, pkt) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 2a2a7ddeb..c827d0277 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -206,7 +206,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // to a specific processing queue. Each queue is serviced by its own processor // goroutine which is responsible for dequeuing and doing full TCP dispatch of // the packet. -func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { p.dispatcher.queuePacket(r, ep, id, pkt) } @@ -217,7 +217,7 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { s := newSegment(r, id, pkt) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 074edded6..0c099e2fd 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -60,7 +60,7 @@ type segment struct { xmitCount uint32 } -func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) *segment { +func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { s := &segment{ refCnt: 1, id: id, diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 7b1d72cf4..9721f6caf 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -316,7 +316,7 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -372,7 +372,7 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: s, }) } @@ -380,7 +380,7 @@ func (c *Context) SendSegment(s buffer.VectorisedView) { // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: c.BuildSegment(payload, h), }) } @@ -389,7 +389,7 @@ func (c *Context) SendPacket(payload []byte, h *Headers) { // & TCPheaders) in an IPv4 packet via the link layer endpoint using the // provided source and destination IPv4 addresses. func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: c.BuildSegmentWithAddrs(payload, h, src, dst), }) } @@ -564,7 +564,7 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 647b2067a..79faa7869 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -921,7 +921,11 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u if useDefaultTTL { ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: ProtocolNumber, + TTL: ttl, + TOS: tos, + }, &stack.PacketBuffer{ Header: hdr, Data: data, TransportHeader: buffer.View(udp), @@ -1269,7 +1273,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { @@ -1336,7 +1340,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { if typ == stack.ControlPortUnreachable { e.mu.RLock() defer e.mu.RUnlock() diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index a674ceb68..7abfa0ed2 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -43,7 +43,7 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, route: r, @@ -61,7 +61,7 @@ type ForwarderRequest struct { stack *stack.Stack route *stack.Route id stack.TransportEndpointID - pkt stack.PacketBuffer + pkt *stack.PacketBuffer } // ID returns the 4-tuple (src address, src port, dst address, dst port) that diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 52af6de22..e320c5758 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -66,7 +66,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { // Get the header then trim it from the view. h, ok := pkt.Data.PullUp(header.UDPMinimumSize) if !ok { @@ -140,7 +140,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv4DstUnreachable) pkt.SetCode(header.ICMPv4PortUnreachable) pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -177,7 +177,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv6DstUnreachable) pkt.SetCode(header.ICMPv6PortUnreachable) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: payload, }) diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 8acaa607a..e8ade882b 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -440,7 +440,7 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), @@ -487,7 +487,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), -- cgit v1.2.3 From 41da7a568b1e4f46b3bc09724996556fb18b4d16 Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Thu, 4 Jun 2020 15:38:33 -0700 Subject: Fix copylocks error about copying IPTables. IPTables.connections contains a sync.RWMutex. Copying it will trigger copylocks analysis. Tested by manually enabling nogo tests. sync.RWMutex is added to IPTables for the additional race condition discovered. PiperOrigin-RevId: 314817019 --- pkg/sentry/socket/netfilter/netfilter.go | 36 ++++++++++++--------------- pkg/sentry/socket/netstack/stack.go | 9 +++---- pkg/tcpip/stack/iptables.go | 42 +++++++++++++++++++++++++++----- pkg/tcpip/stack/iptables_types.go | 15 ++++++++---- pkg/tcpip/stack/stack.go | 23 ++++------------- pkg/tcpip/transport/icmp/endpoint.go | 5 ---- pkg/tcpip/transport/packet/endpoint.go | 5 ---- pkg/tcpip/transport/raw/endpoint.go | 5 ---- pkg/tcpip/transport/tcp/endpoint.go | 5 ---- pkg/tcpip/transport/udp/endpoint.go | 5 ---- runsc/boot/loader.go | 2 +- 11 files changed, 71 insertions(+), 81 deletions(-) (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 47ff48c00..66015e2bc 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -144,31 +144,27 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen } func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) { - ipt := stk.IPTables() - table, ok := ipt.Tables[tablename.String()] + table, ok := stk.IPTables().GetTable(tablename.String()) if !ok { return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename) } return table, nil } -// FillDefaultIPTables sets stack's IPTables to the default tables and -// populates them with metadata. -func FillDefaultIPTables(stk *stack.Stack) { - ipt := stack.DefaultTables() - - // In order to fill in the metadata, we have to translate ipt from its - // netstack format to Linux's giant-binary-blob format. - for name, table := range ipt.Tables { - _, metadata, err := convertNetstackToBinary(name, table) - if err != nil { - panic(fmt.Errorf("Unable to set default IP tables: %v", err)) +// FillIPTablesMetadata populates stack's IPTables with metadata. +func FillIPTablesMetadata(stk *stack.Stack) { + stk.IPTables().ModifyTables(func(tables map[string]stack.Table) { + // In order to fill in the metadata, we have to translate ipt from its + // netstack format to Linux's giant-binary-blob format. + for name, table := range tables { + _, metadata, err := convertNetstackToBinary(name, table) + if err != nil { + panic(fmt.Errorf("Unable to set default IP tables: %v", err)) + } + table.SetMetadata(metadata) + tables[name] = table } - table.SetMetadata(metadata) - ipt.Tables[name] = table - } - - stk.SetIPTables(ipt) + }) } // convertNetstackToBinary converts the iptables as stored in netstack to the @@ -573,15 +569,13 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - ipt := stk.IPTables() table.SetMetadata(metadata{ HookEntry: replace.HookEntry, Underflow: replace.Underflow, NumEntries: replace.NumEntries, Size: replace.Size, }) - ipt.Tables[replace.Name.String()] = table - stk.SetIPTables(ipt) + stk.IPTables().ReplaceTable(replace.Name.String(), table) return nil } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index f5fa18136..9b44c2b89 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -362,14 +362,13 @@ func (s *Stack) RouteTable() []inet.Route { } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() (stack.IPTables, error) { +func (s *Stack) IPTables() (*stack.IPTables, error) { return s.Stack.IPTables(), nil } -// FillDefaultIPTables sets the stack's iptables to the default tables, which -// allow and do not modify all traffic. -func (s *Stack) FillDefaultIPTables() { - netfilter.FillDefaultIPTables(s.Stack) +// FillIPTablesMetadata populates stack's IPTables with metadata. +func (s *Stack) FillIPTablesMetadata() { + netfilter.FillIPTablesMetadata(s.Stack) } // Resume implements inet.Stack.Resume. diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index d989dbe91..4e9b404c8 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -43,11 +43,11 @@ const HookUnset = -1 // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. -func DefaultTables() IPTables { +func DefaultTables() *IPTables { // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for // iotas. - return IPTables{ - Tables: map[string]Table{ + return &IPTables{ + tables: map[string]Table{ TablenameNat: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, @@ -106,7 +106,7 @@ func DefaultTables() IPTables { UserChains: map[string]int{}, }, }, - Priorities: map[Hook][]string{ + priorities: map[Hook][]string{ Input: []string{TablenameNat, TablenameFilter}, Prerouting: []string{TablenameMangle, TablenameNat}, Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, @@ -158,6 +158,36 @@ func EmptyNatTable() Table { } } +// GetTable returns table by name. +func (it *IPTables) GetTable(name string) (Table, bool) { + it.mu.RLock() + defer it.mu.RUnlock() + t, ok := it.tables[name] + return t, ok +} + +// ReplaceTable replaces or inserts table by name. +func (it *IPTables) ReplaceTable(name string, table Table) { + it.mu.Lock() + defer it.mu.Unlock() + it.tables[name] = table +} + +// ModifyTables acquires write-lock and calls fn with internal name-to-table +// map. This function can be used to update multiple tables atomically. +func (it *IPTables) ModifyTables(fn func(map[string]Table)) { + it.mu.Lock() + defer it.mu.Unlock() + fn(it.tables) +} + +// GetPriorities returns slice of priorities associated with hook. +func (it *IPTables) GetPriorities(hook Hook) []string { + it.mu.RLock() + defer it.mu.RUnlock() + return it.priorities[hook] +} + // A chainVerdict is what a table decides should be done with a packet. type chainVerdict int @@ -184,8 +214,8 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr it.connections.HandlePacket(pkt, hook, gso, r) // Go through each table containing the hook. - for _, tablename := range it.Priorities[hook] { - table := it.Tables[tablename] + for _, tablename := range it.GetPriorities(hook) { + table, _ := it.GetTable(tablename) ruleIdx := table.BuiltinChains[hook] switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { // If the table returns Accept, move on to the next table. diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index af72b9c46..4a6a5c6f1 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -16,6 +16,7 @@ package stack import ( "strings" + "sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -78,13 +79,17 @@ const ( // IPTables holds all the tables for a netstack. type IPTables struct { - // Tables maps table names to tables. User tables have arbitrary names. - Tables map[string]Table + // mu protects tables and priorities. + mu sync.RWMutex - // Priorities maps each hook to a list of table names. The order of the + // tables maps table names to tables. User tables have arbitrary names. mu + // needs to be locked for accessing. + tables map[string]Table + + // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that - // hook. - Priorities map[Hook][]string + // hook. mu needs to be locked for accessing. + priorities map[Hook][]string connections ConnTrackTable } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 8af06cb9a..294ce8775 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -424,12 +424,8 @@ type Stack struct { // handleLocal allows non-loopback interfaces to loop packets. handleLocal bool - // tablesMu protects iptables. - tablesMu sync.RWMutex - - // tables are the iptables packet filtering and manipulation rules. The are - // protected by tablesMu.` - tables IPTables + // tables are the iptables packet filtering and manipulation rules. + tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the // stack is being restored. @@ -676,6 +672,7 @@ func New(opts Options) *Stack { clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, + tables: DefaultTables(), icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), ndpConfigs: opts.NDPConfigs, @@ -1741,18 +1738,8 @@ func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() IPTables { - s.tablesMu.RLock() - t := s.tables - s.tablesMu.RUnlock() - return t -} - -// SetIPTables sets the stack's iptables. -func (s *Stack) SetIPTables(ipt IPTables) { - s.tablesMu.Lock() - s.tables = ipt - s.tablesMu.Unlock() +func (s *Stack) IPTables() *IPTables { + return s.tables } // ICMPLimit returns the maximum number of ICMP messages that can be sent diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 29ff68df3..3bc72bc19 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -140,11 +140,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index bab2d63ae..baf08eda6 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -132,11 +132,6 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} -// IPTables implements tcpip.Endpoint.IPTables. -func (ep *endpoint) IPTables() (stack.IPTables, error) { - return ep.stack.IPTables(), nil -} - // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { ep.rcvMu.Lock() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 25a17940d..21c34fac2 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -166,11 +166,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read implements tcpip.Endpoint.Read. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { if !e.associated { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index d048ef90c..edca98160 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1172,11 +1172,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 79faa7869..663af8fec 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -247,11 +247,6 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index f802bc9fb..002479612 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -1056,7 +1056,7 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err) } - s.FillDefaultIPTables() + s.FillIPTablesMetadata() return &s, nil } -- cgit v1.2.3 From 28b8a5cc3ac538333756084da28d7f13f13b5c87 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 18 Jun 2020 17:00:47 -0700 Subject: iptables: remove metadata struct Metadata was useful for debugging and safety, but enough tests exist that we should see failures when (de)serialization is broken. It made stack initialization more cumbersome and it's also getting in the way of ip6tables. PiperOrigin-RevId: 317210653 --- pkg/sentry/socket/netfilter/netfilter.go | 103 ++++++------------------------- pkg/sentry/socket/netstack/stack.go | 6 -- pkg/tcpip/stack/iptables.go | 8 --- pkg/tcpip/stack/iptables_types.go | 16 +---- runsc/boot/loader.go | 2 - 5 files changed, 21 insertions(+), 114 deletions(-) (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 66015e2bc..f7abe77d3 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -41,19 +41,6 @@ const errorTargetName = "ERROR" // change the destination port/destination IP for packets. const redirectTargetName = "REDIRECT" -// Metadata is used to verify that we are correctly serializing and -// deserializing iptables into structs consumable by the iptables tool. We save -// a metadata struct when the tables are written, and when they are read out we -// verify that certain fields are the same. -// -// metadata is used by this serialization/deserializing code, not netstack. -type metadata struct { - HookEntry [linux.NF_INET_NUMHOOKS]uint32 - Underflow [linux.NF_INET_NUMHOOKS]uint32 - NumEntries uint32 - Size uint32 -} - // enableLogging controls whether to log the (de)serialization of netfilter // structs between userspace and netstack. These logs are useful when // developing iptables, but can pollute sentry logs otherwise. @@ -83,29 +70,13 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT return linux.IPTGetinfo{}, syserr.FromError(err) } - // Find the appropriate table. - table, err := findTable(stack, info.Name) + _, info, err := convertNetstackToBinary(stack, info.Name) if err != nil { - nflog("%v", err) + nflog("couldn't convert iptables: %v", err) return linux.IPTGetinfo{}, syserr.ErrInvalidArgument } - // Get the hooks that apply to this table. - info.ValidHooks = table.ValidHooks() - - // Grab the metadata struct, which is used to store information (e.g. - // the number of entries) that applies to the user's encoding of - // iptables, but not netstack's. - metadata := table.Metadata().(metadata) - - // Set values from metadata. - info.HookEntry = metadata.HookEntry - info.Underflow = metadata.Underflow - info.NumEntries = metadata.NumEntries - info.Size = metadata.Size - nflog("returning info: %+v", info) - return info, nil } @@ -118,23 +89,13 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen return linux.KernelIPTGetEntries{}, syserr.FromError(err) } - // Find the appropriate table. - table, err := findTable(stack, userEntries.Name) - if err != nil { - nflog("%v", err) - return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument - } - // Convert netstack's iptables rules to something that the iptables // tool can understand. - entries, meta, err := convertNetstackToBinary(userEntries.Name.String(), table) + entries, _, err := convertNetstackToBinary(stack, userEntries.Name) if err != nil { nflog("couldn't read entries: %v", err) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } - if meta != table.Metadata().(metadata) { - panic(fmt.Sprintf("Table %q metadata changed between writing and reading. Was saved as %+v, but is now %+v", userEntries.Name.String(), table.Metadata().(metadata), meta)) - } if binary.Size(entries) > uintptr(outLen) { nflog("insufficient GetEntries output size: %d", uintptr(outLen)) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument @@ -143,44 +104,26 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen return entries, nil } -func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) { - table, ok := stk.IPTables().GetTable(tablename.String()) - if !ok { - return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename) - } - return table, nil -} - -// FillIPTablesMetadata populates stack's IPTables with metadata. -func FillIPTablesMetadata(stk *stack.Stack) { - stk.IPTables().ModifyTables(func(tables map[string]stack.Table) { - // In order to fill in the metadata, we have to translate ipt from its - // netstack format to Linux's giant-binary-blob format. - for name, table := range tables { - _, metadata, err := convertNetstackToBinary(name, table) - if err != nil { - panic(fmt.Errorf("Unable to set default IP tables: %v", err)) - } - table.SetMetadata(metadata) - tables[name] = table - } - }) -} - // convertNetstackToBinary converts the iptables as stored in netstack to the // format expected by the iptables tool. Linux stores each table as a binary // blob that can only be traversed by parsing a bit, reading some offsets, // jumping to those offsets, parsing again, etc. -func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelIPTGetEntries, metadata, error) { - // Return values. +func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) { + table, ok := stack.IPTables().GetTable(tablename.String()) + if !ok { + return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) + } + var entries linux.KernelIPTGetEntries - var meta metadata + var info linux.IPTGetinfo + info.ValidHooks = table.ValidHooks() // The table name has to fit in the struct. if linux.XT_TABLE_MAXNAMELEN < len(tablename) { - return linux.KernelIPTGetEntries{}, metadata{}, fmt.Errorf("table name %q too long.", tablename) + return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) } - copy(entries.Name[:], tablename) + copy(info.Name[:], tablename[:]) + copy(entries.Name[:], tablename[:]) for ruleIdx, rule := range table.Rules { nflog("convert to binary: current offset: %d", entries.Size) @@ -189,14 +132,14 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI for hook, hookRuleIdx := range table.BuiltinChains { if hookRuleIdx == ruleIdx { nflog("convert to binary: found hook %d at offset %d", hook, entries.Size) - meta.HookEntry[hook] = entries.Size + info.HookEntry[hook] = entries.Size } } // Is this a chain underflow point? for underflow, underflowRuleIdx := range table.Underflows { if underflowRuleIdx == ruleIdx { nflog("convert to binary: found underflow %d at offset %d", underflow, entries.Size) - meta.Underflow[underflow] = entries.Size + info.Underflow[underflow] = entries.Size } } @@ -251,12 +194,12 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI entries.Size += uint32(entry.NextOffset) entries.Entrytable = append(entries.Entrytable, entry) - meta.NumEntries++ + info.NumEntries++ } - nflog("convert to binary: finished with an marshalled size of %d", meta.Size) - meta.Size = entries.Size - return entries, meta, nil + nflog("convert to binary: finished with an marshalled size of %d", info.Size) + info.Size = entries.Size + return entries, info, nil } func marshalTarget(target stack.Target) []byte { @@ -569,12 +512,6 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - table.SetMetadata(metadata{ - HookEntry: replace.HookEntry, - Underflow: replace.Underflow, - NumEntries: replace.NumEntries, - Size: replace.Size, - }) stk.IPTables().ReplaceTable(replace.Name.String(), table) return nil diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index f97f9b6f3..ee11742a6 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -18,7 +18,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" - "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -366,11 +365,6 @@ func (s *Stack) IPTables() (*stack.IPTables, error) { return s.Stack.IPTables(), nil } -// FillIPTablesMetadata populates stack's IPTables with metadata. -func (s *Stack) FillIPTablesMetadata() { - netfilter.FillIPTablesMetadata(s.Stack) -} - // Resume implements inet.Stack.Resume. func (s *Stack) Resume() { s.Stack.Resume() diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 4e9b404c8..dc2b77c9d 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -173,14 +173,6 @@ func (it *IPTables) ReplaceTable(name string, table Table) { it.tables[name] = table } -// ModifyTables acquires write-lock and calls fn with internal name-to-table -// map. This function can be used to update multiple tables atomically. -func (it *IPTables) ModifyTables(fn func(map[string]Table)) { - it.mu.Lock() - defer it.mu.Unlock() - fn(it.tables) -} - // GetPriorities returns slice of priorities associated with hook. func (it *IPTables) GetPriorities(hook Hook) []string { it.mu.RLock() diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 4a6a5c6f1..72f1dd329 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -95,7 +95,7 @@ type IPTables struct { } // A Table defines a set of chains and hooks into the network stack. It is -// really just a list of rules with some metadata for entrypoints and such. +// really just a list of rules. type Table struct { // Rules holds the rules that make up the table. Rules []Rule @@ -110,10 +110,6 @@ type Table struct { // UserChains holds user-defined chains for the keyed by name. Users // can give their chains arbitrary names. UserChains map[string]int - - // Metadata holds information about the Table that is useful to users - // of IPTables, but not to the netstack IPTables code itself. - metadata interface{} } // ValidHooks returns a bitmap of the builtin hooks for the given table. @@ -125,16 +121,6 @@ func (table *Table) ValidHooks() uint32 { return hooks } -// Metadata returns the metadata object stored in table. -func (table *Table) Metadata() interface{} { - return table.metadata -} - -// SetMetadata sets the metadata object stored in table. -func (table *Table) SetMetadata(metadata interface{}) { - table.metadata = metadata -} - // A Rule is a packet processing rule. It consists of two pieces. First it // contains zero or more matchers, each of which is a specification of which // packets this rule applies to. If there are no matchers in the rule, it diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index c6efcdc83..081db39c1 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -1071,8 +1071,6 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in return nil, fmt.Errorf("SetTransportProtocolOption failed: %s", err) } - s.FillIPTablesMetadata() - return &s, nil } -- cgit v1.2.3 From 0c169b6ad598200a57db7bf0f679da1d6cb395c4 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 18 Jun 2020 19:45:13 -0700 Subject: iptables: skip iptables if no rules are set Users that never set iptables rules shouldn't incur the iptables performance cost. Suggested by Ian (@iangudger). PiperOrigin-RevId: 317232921 --- pkg/tcpip/stack/iptables.go | 10 ++++++++++ pkg/tcpip/stack/iptables_types.go | 11 ++++++++--- 2 files changed, 18 insertions(+), 3 deletions(-) (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index dc2b77c9d..62d4eb1b6 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -170,6 +170,7 @@ func (it *IPTables) GetTable(name string) (Table, bool) { func (it *IPTables) ReplaceTable(name string, table Table) { it.mu.Lock() defer it.mu.Unlock() + it.modified = true it.tables[name] = table } @@ -201,6 +202,15 @@ const ( // // Precondition: pkt.NetworkHeader is set. func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool { + // Many users never configure iptables. Spare them the cost of rule + // traversal if rules have never been set. + it.mu.RLock() + if !it.modified { + it.mu.RUnlock() + return true + } + it.mu.RUnlock() + // Packets are manipulated only if connection and matching // NAT rule exists. it.connections.HandlePacket(pkt, hook, gso, r) diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 72f1dd329..7026990c4 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -79,11 +79,11 @@ const ( // IPTables holds all the tables for a netstack. type IPTables struct { - // mu protects tables and priorities. + // mu protects tables, priorities, and modified. mu sync.RWMutex - // tables maps table names to tables. User tables have arbitrary names. mu - // needs to be locked for accessing. + // tables maps table names to tables. User tables have arbitrary names. + // mu needs to be locked for accessing. tables map[string]Table // priorities maps each hook to a list of table names. The order of the @@ -91,6 +91,11 @@ type IPTables struct { // hook. mu needs to be locked for accessing. priorities map[Hook][]string + // modified is whether tables have been modified at least once. It is + // used to elide the iptables performance overhead for workloads that + // don't utilize iptables. + modified bool + connections ConnTrackTable } -- cgit v1.2.3 From 7fb6cc286fffab2f408a5dbc228e0db706104682 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 25 Jun 2020 21:20:29 -0700 Subject: conntrack refactor, no behavior changes - Split connTrackForPacket into 2 functions instead of switching on flag - Replace hash with struct keys. - Remove prefixes where possible - Remove unused connStatus, timeout - Flatten ConnTrack struct a bit - some intermediate structs had no meaning outside of the context of their parent. - Protect conn.tcb with a mutex - Remove redundant error checking (e.g. when is pkt.NetworkHeader valid) - Clarify that HandlePacket and CreateConnFor are the expected entrypoints for ConnTrack PiperOrigin-RevId: 318407168 --- pkg/sentry/socket/netfilter/targets.go | 2 +- pkg/tcpip/stack/conntrack.go | 405 ++++++++++++--------------------- pkg/tcpip/stack/iptables.go | 7 +- pkg/tcpip/stack/iptables_targets.go | 23 +- pkg/tcpip/stack/iptables_types.go | 4 +- 5 files changed, 168 insertions(+), 273 deletions(-) (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index 84abe8d29..b91ba3ab3 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -30,6 +30,6 @@ type JumpTarget struct { } // Action implements stack.Target.Action. -func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrackTable, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { +func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 05bf62788..af9c325ca 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -15,39 +15,29 @@ package stack import ( - "encoding/binary" "sync" - "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" ) // Connection tracking is used to track and manipulate packets for NAT rules. -// The connection is created for a packet if it does not exist. Every connection -// contains two tuples (original and reply). The tuples are manipulated if there -// is a matching NAT rule. The packet is modified by looking at the tuples in the -// Prerouting and Output hooks. +// The connection is created for a packet if it does not exist. Every +// connection contains two tuples (original and reply). The tuples are +// manipulated if there is a matching NAT rule. The packet is modified by +// looking at the tuples in the Prerouting and Output hooks. +// +// Currently, only TCP tracking is supported. // Direction of the tuple. -type ctDirection int +type direction int const ( - dirOriginal ctDirection = iota + dirOriginal direction = iota dirReply ) -// Status of connection. -// TODO(gvisor.dev/issue/170): Add other states of connection. -type connStatus int - -const ( - connNew connStatus = iota - connEstablished -) - // Manipulation type for the connection. type manipType int @@ -56,250 +46,171 @@ const ( manipDstOutput ) -// connTrackMutable is the manipulatable part of the tuple. -type connTrackMutable struct { - // addr is source address of the tuple. - addr tcpip.Address +// tuple holds a connection's identifying and manipulating data in one +// direction. It is immutable. +type tuple struct { + tupleID - // port is source port of the tuple. - port uint16 + // conn is the connection tracking entry this tuple belongs to. + conn *conn - // protocol is network layer protocol. - protocol tcpip.NetworkProtocolNumber + // direction is the direction of the tuple. + direction direction } -// connTrackImmutable is the non-manipulatable part of the tuple. -type connTrackImmutable struct { - // addr is destination address of the tuple. - addr tcpip.Address - - // direction is direction (original or reply) of the tuple. - direction ctDirection - - // port is destination port of the tuple. - port uint16 - - // protocol is transport layer protocol. - protocol tcpip.TransportProtocolNumber +// tupleID uniquely identifies a connection in one direction. It currently +// contains enough information to distinguish between any TCP or UDP +// connection, and will need to be extended to support other protocols. +type tupleID struct { + srcAddr tcpip.Address + srcPort uint16 + dstAddr tcpip.Address + dstPort uint16 + transProto tcpip.TransportProtocolNumber + netProto tcpip.NetworkProtocolNumber } -// connTrackTuple represents the tuple which is created from the -// packet. -type connTrackTuple struct { - // dst is non-manipulatable part of the tuple. - dst connTrackImmutable - - // src is manipulatable part of the tuple. - src connTrackMutable +// reply creates the reply tupleID. +func (ti tupleID) reply() tupleID { + return tupleID{ + srcAddr: ti.dstAddr, + srcPort: ti.dstPort, + dstAddr: ti.srcAddr, + dstPort: ti.srcPort, + transProto: ti.transProto, + netProto: ti.netProto, + } } -// connTrackTupleHolder is the container of tuple and connection. -type ConnTrackTupleHolder struct { - // conn is pointer to the connection tracking entry. - conn *connTrack - - // tuple is original or reply tuple. - tuple connTrackTuple -} +// conn is a tracked connection. +type conn struct { + // original is the tuple in original direction. It is immutable. + original tuple -// connTrack is the connection. -type connTrack struct { - // originalTupleHolder contains tuple in original direction. - originalTupleHolder ConnTrackTupleHolder + // reply is the tuple in reply direction. It is immutable. + reply tuple - // replyTupleHolder contains tuple in reply direction. - replyTupleHolder ConnTrackTupleHolder + // manip indicates if the packet should be manipulated. It is immutable. + manip manipType - // status indicates connection is new or established. - status connStatus + // tcbHook indicates if the packet is inbound or outbound to + // update the state of tcb. It is immutable. + tcbHook Hook - // timeout indicates the time connection should be active. - timeout time.Duration - - // manip indicates if the packet should be manipulated. - manip manipType + // mu protects tcb. + mu sync.Mutex // tcb is TCB control block. It is used to keep track of states - // of tcp connection. + // of tcp connection and is protected by mu. tcb tcpconntrack.TCB - - // tcbHook indicates if the packet is inbound or outbound to - // update the state of tcb. - tcbHook Hook } -// ConnTrackTable contains a map of all existing connections created for -// NAT rules. -type ConnTrackTable struct { - // connMu protects connTrackTable. - connMu sync.RWMutex - - // connTrackTable maintains a map of tuples needed for connection tracking - // for iptables NAT rules. The key for the map is an integer calculated - // using seed, source address, destination address, source port and - // destination port. - CtMap map[uint32]ConnTrackTupleHolder - - // seed is a one-time random value initialized at stack startup - // and is used in calculation of hash key for connection tracking - // table. - Seed uint32 -} +// ConnTrack tracks all connections created for NAT rules. Most users are +// expected to only call handlePacket and createConnFor. +type ConnTrack struct { + // mu protects conns. + mu sync.RWMutex -// packetToTuple converts packet to a tuple in original direction. -func packetToTuple(pkt *PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { - var tuple connTrackTuple + // conns maintains a map of tuples needed for connection tracking for + // iptables NAT rules. It is protected by mu. + conns map[tupleID]tuple +} - netHeader := header.IPv4(pkt.NetworkHeader) +// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid +// TCP header. +func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { // TODO(gvisor.dev/issue/170): Need to support for other // protocols as well. + netHeader := header.IPv4(pkt.NetworkHeader) if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { - return tuple, tcpip.ErrUnknownProtocol + return tupleID{}, tcpip.ErrUnknownProtocol } tcpHeader := header.TCP(pkt.TransportHeader) if tcpHeader == nil { - return tuple, tcpip.ErrUnknownProtocol + return tupleID{}, tcpip.ErrUnknownProtocol } - tuple.src.addr = netHeader.SourceAddress() - tuple.src.port = tcpHeader.SourcePort() - tuple.src.protocol = header.IPv4ProtocolNumber - - tuple.dst.addr = netHeader.DestinationAddress() - tuple.dst.port = tcpHeader.DestinationPort() - tuple.dst.protocol = netHeader.TransportProtocol() - - return tuple, nil -} - -// getReplyTuple creates reply tuple for the given tuple. -func getReplyTuple(tuple connTrackTuple) connTrackTuple { - var replyTuple connTrackTuple - replyTuple.src.addr = tuple.dst.addr - replyTuple.src.port = tuple.dst.port - replyTuple.src.protocol = tuple.src.protocol - replyTuple.dst.addr = tuple.src.addr - replyTuple.dst.port = tuple.src.port - replyTuple.dst.protocol = tuple.dst.protocol - replyTuple.dst.direction = dirReply - - return replyTuple + return tupleID{ + srcAddr: netHeader.SourceAddress(), + srcPort: tcpHeader.SourcePort(), + dstAddr: netHeader.DestinationAddress(), + dstPort: tcpHeader.DestinationPort(), + transProto: netHeader.TransportProtocol(), + netProto: header.IPv4ProtocolNumber, + }, nil } -// makeNewConn creates new connection. -func makeNewConn(tuple, replyTuple connTrackTuple) connTrack { - var conn connTrack - conn.status = connNew - conn.originalTupleHolder.tuple = tuple - conn.originalTupleHolder.conn = &conn - conn.replyTupleHolder.tuple = replyTuple - conn.replyTupleHolder.conn = &conn - - return conn -} - -// getTupleHash returns hash of the tuple. The fields used for -// generating hash are seed (generated once for stack), source address, -// destination address, source port and destination ports. -func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 { - h := jenkins.Sum32(ct.Seed) - h.Write([]byte(tuple.src.addr)) - h.Write([]byte(tuple.dst.addr)) - portBuf := make([]byte, 2) - binary.LittleEndian.PutUint16(portBuf, tuple.src.port) - h.Write([]byte(portBuf)) - binary.LittleEndian.PutUint16(portBuf, tuple.dst.port) - h.Write([]byte(portBuf)) - - return h.Sum32() +// newConn creates new connection. +func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { + conn := conn{ + manip: manip, + tcbHook: hook, + } + conn.original = tuple{conn: &conn, tupleID: orig} + conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} + return &conn } -// connTrackForPacket returns connTrack for packet. -// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other -// transport protocols. -func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) { - var dir ctDirection - tuple, err := packetToTuple(pkt, hook) +// connFor gets the conn for pkt if it exists, or returns nil +// if it does not. It returns an error when pkt does not contain a valid TCP +// header. +// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support +// other transport protocols. +func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { + tid, err := packetToTupleID(pkt) if err != nil { - return nil, dir + return nil, dirOriginal } - ct.connMu.Lock() - defer ct.connMu.Unlock() - - connTrackTable := ct.CtMap - hash := ct.getTupleHash(tuple) - - var conn *connTrack - switch createConn { - case true: - // If connection does not exist for the hash, create a new - // connection. - replyTuple := getReplyTuple(tuple) - replyHash := ct.getTupleHash(replyTuple) - newConn := makeNewConn(tuple, replyTuple) - conn = &newConn - - // Add tupleHolders to the map. - // TODO(gvisor.dev/issue/170): Need to support collisions using linked list. - ct.CtMap[hash] = conn.originalTupleHolder - ct.CtMap[replyHash] = conn.replyTupleHolder - default: - tupleHolder, ok := connTrackTable[hash] - if !ok { - return nil, dir - } - - // If this is the reply of new connection, set the connection - // status as ESTABLISHED. - conn = tupleHolder.conn - if conn.status == connNew && tupleHolder.tuple.dst.direction == dirReply { - conn.status = connEstablished - } - if tupleHolder.conn == nil { - panic("tupleHolder has null connection tracking entry") - } + ct.mu.Lock() + defer ct.mu.Unlock() - dir = tupleHolder.tuple.dst.direction + tuple, ok := ct.conns[tid] + if !ok { + return nil, dirOriginal } - return conn, dir + return tuple.conn, tuple.direction } -// SetNatInfo will manipulate the tuples according to iptables NAT rules. -func (ct *ConnTrackTable) SetNatInfo(pkt *PacketBuffer, rt RedirectTarget, hook Hook) { - // Get the connection. Connection is always created before this - // function is called. - conn, _ := ct.connTrackForPacket(pkt, hook, false) - if conn == nil { - panic("connection should be created to manipulate tuples.") +// createConnFor creates a new conn for pkt. +func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { + tid, err := packetToTupleID(pkt) + if err != nil { + return nil + } + if hook != Prerouting && hook != Output { + return nil } - replyTuple := conn.replyTupleHolder.tuple - replyHash := ct.getTupleHash(replyTuple) - - // TODO(gvisor.dev/issue/170): Support only redirect of ports. Need to - // support changing of address for Prerouting. - // Change the port as per the iptables rule. This tuple will be used - // to manipulate the packet in HandlePacket. - conn.replyTupleHolder.tuple.src.addr = rt.MinIP - conn.replyTupleHolder.tuple.src.port = rt.MinPort - newHash := ct.getTupleHash(conn.replyTupleHolder.tuple) + // Create a new connection and change the port as per the iptables + // rule. This tuple will be used to manipulate the packet in + // handlePacket. + replyTID := tid.reply() + replyTID.srcAddr = rt.MinIP + replyTID.srcPort = rt.MinPort + var manip manipType + switch hook { + case Prerouting: + manip = manipDstPrerouting + case Output: + manip = manipDstOutput + } + conn := newConn(tid, replyTID, manip, hook) // Add the changed tuple to the map. - ct.connMu.Lock() - defer ct.connMu.Unlock() - ct.CtMap[newHash] = conn.replyTupleHolder - if hook == Output { - conn.replyTupleHolder.conn.manip = manipDstOutput - } + // TODO(gvisor.dev/issue/170): Need to support collisions using linked + // list. + ct.mu.Lock() + defer ct.mu.Unlock() + ct.conns[tid] = conn.original + ct.conns[replyTID] = conn.reply - // Delete the old tuple. - delete(ct.CtMap, replyHash) + return conn } // handlePacketPrerouting manipulates ports for packets in Prerouting hook. -// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.. -func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) { +// TODO(gvisor.dev/issue/170): Change address for Prerouting hook. +func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -308,13 +219,13 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) // modified. switch dir { case dirOriginal: - port := conn.replyTupleHolder.tuple.src.port + port := conn.reply.srcPort tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr) + netHeader.SetDestinationAddress(conn.reply.srcAddr) case dirReply: - port := conn.originalTupleHolder.tuple.dst.port + port := conn.original.dstPort tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr) + netHeader.SetSourceAddress(conn.original.dstAddr) } netHeader.SetChecksum(0) @@ -322,7 +233,7 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) } // handlePacketOutput manipulates ports for packets in Output hook. -func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, dir ctDirection) { +func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) { netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -331,13 +242,13 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, // modified. For prerouting redirection, we only reach this point // when replying, so packet sources are modified. if conn.manip == manipDstOutput && dir == dirOriginal { - port := conn.replyTupleHolder.tuple.src.port + port := conn.reply.srcPort tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr) + netHeader.SetDestinationAddress(conn.reply.srcAddr) } else { - port := conn.originalTupleHolder.tuple.dst.port + port := conn.original.dstPort tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr) + netHeader.SetSourceAddress(conn.original.dstAddr) } // Calculate the TCP checksum and set it. @@ -356,9 +267,9 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, netHeader.SetChecksum(^netHeader.CalculateChecksum()) } -// HandlePacket will manipulate the port and address of the packet if the +// handlePacket will manipulate the port and address of the packet if the // connection exists. -func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) { +func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) { if pkt.NatDone { return } @@ -367,21 +278,9 @@ func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r return } - conn, dir := ct.connTrackForPacket(pkt, hook, false) - // Connection or Rule not found for the packet. + conn, dir := ct.connFor(pkt) if conn == nil { - return - } - - netHeader := header.IPv4(pkt.NetworkHeader) - // TODO(gvisor.dev/issue/170): Need to support for other transport - // protocols as well. - if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { - return - } - - tcpHeader := header.TCP(pkt.TransportHeader) - if tcpHeader == nil { + // Connection not found for the packet or the packet is invalid. return } @@ -396,7 +295,10 @@ func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r // Update the state of tcb. // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle // other tcp states. + conn.mu.Lock() + defer conn.mu.Unlock() var st tcpconntrack.Result + tcpHeader := header.TCP(pkt.TransportHeader) if conn.tcb.IsEmpty() { conn.tcb.Init(tcpHeader) conn.tcbHook = hook @@ -409,26 +311,21 @@ func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r } } - // Delete conntrack if tcp connection is closed. + // Delete conn if tcp connection is closed. if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset { - ct.deleteConnTrack(conn) + ct.deleteConn(conn) } } -// deleteConnTrack deletes the connection. -func (ct *ConnTrackTable) deleteConnTrack(conn *connTrack) { +// deleteConn deletes the connection. +func (ct *ConnTrack) deleteConn(conn *conn) { if conn == nil { return } - tuple := conn.originalTupleHolder.tuple - hash := ct.getTupleHash(tuple) - replyTuple := conn.replyTupleHolder.tuple - replyHash := ct.getTupleHash(replyTuple) - - ct.connMu.Lock() - defer ct.connMu.Unlock() + ct.mu.Lock() + defer ct.mu.Unlock() - delete(ct.CtMap, hash) - delete(ct.CtMap, replyHash) + delete(ct.conns, conn.original.tupleID) + delete(ct.conns, conn.reply.tupleID) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 62d4eb1b6..974d77c36 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -111,9 +111,8 @@ func DefaultTables() *IPTables { Prerouting: []string{TablenameMangle, TablenameNat}, Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, }, - connections: ConnTrackTable{ - CtMap: make(map[uint32]ConnTrackTupleHolder), - Seed: generateRandUint32(), + connections: ConnTrack{ + conns: make(map[tupleID]tuple), }, } } @@ -213,7 +212,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // Packets are manipulated only if connection and matching // NAT rule exists. - it.connections.HandlePacket(pkt, hook, gso, r) + it.connections.handlePacket(pkt, hook, gso, r) // Go through each table containing the hook. for _, tablename := range it.GetPriorities(hook) { diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 92e31643e..d43f60c67 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -24,7 +24,7 @@ import ( type AcceptTarget struct{} // Action implements Target.Action. -func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 } @@ -32,7 +32,7 @@ func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, t type DropTarget struct{} // Action implements Target.Action. -func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } @@ -41,7 +41,7 @@ func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcp type ErrorTarget struct{} // Action implements Target.Action. -func (ErrorTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -52,7 +52,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -61,7 +61,7 @@ func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route type ReturnTarget struct{} // Action implements Target.Action. -func (ReturnTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } @@ -92,7 +92,7 @@ type RedirectTarget struct { // TODO(gvisor.dev/issue/170): Parse headers without copying. The current // implementation only works for PREROUTING and calls pkt.Clone(), neither // of which should be the case. -func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 @@ -150,12 +150,11 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook return RuleAccept, 0 } - // Set up conection for matching NAT rule. - // Only the first packet of the connection comes here. - // Other packets will be manipulated in connection tracking. - if conn, _ := ct.connTrackForPacket(pkt, hook, true); conn != nil { - ct.SetNatInfo(pkt, rt, hook) - ct.HandlePacket(pkt, hook, gso, r) + // Set up conection for matching NAT rule. Only the first + // packet of the connection comes here. Other packets will be + // manipulated in connection tracking. + if conn := ct.createConnFor(pkt, hook, rt); conn != nil { + ct.handlePacket(pkt, hook, gso, r) } default: return RuleDrop, 0 diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 7026990c4..c528ec381 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -96,7 +96,7 @@ type IPTables struct { // don't utilize iptables. modified bool - connections ConnTrackTable + connections ConnTrack } // A Table defines a set of chains and hooks into the network stack. It is @@ -249,5 +249,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. - Action(packet *PacketBuffer, connections *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) + Action(packet *PacketBuffer, connections *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) } -- cgit v1.2.3 From 43c209f48e0aa9024705583cc6f0fafa7d6380ca Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Mon, 13 Jul 2020 11:59:26 -0700 Subject: garbage collect connections As in Linux, we must periodically clean up unused connections. PiperOrigin-RevId: 321003353 --- pkg/tcpip/stack/BUILD | 14 ++ pkg/tcpip/stack/conntrack.go | 277 ++++++++++++++++++---- pkg/tcpip/stack/iptables.go | 42 +++- pkg/tcpip/stack/iptables_state.go | 40 ++++ pkg/tcpip/stack/iptables_types.go | 11 + pkg/tcpip/stack/stack.go | 1 + pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go | 5 + 7 files changed, 347 insertions(+), 43 deletions(-) create mode 100644 pkg/tcpip/stack/iptables_state.go (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index e65c731c2..800bf3f08 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -27,6 +27,18 @@ go_template_instance( }, ) +go_template_instance( + name = "tuple_list", + out = "tuple_list.go", + package = "stack", + prefix = "tuple", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*tuple", + "Linker": "*tuple", + }, +) + go_library( name = "stack", srcs = [ @@ -35,6 +47,7 @@ go_library( "forwarder.go", "icmp_rate_limit.go", "iptables.go", + "iptables_state.go", "iptables_targets.go", "iptables_types.go", "linkaddrcache.go", @@ -50,6 +63,7 @@ go_library( "stack_global_state.go", "stack_options.go", "transport_demuxer.go", + "tuple_list.go", ], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index af9c325ca..d39baf620 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -15,9 +15,12 @@ package stack import ( + "encoding/binary" "sync" + "time" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" ) @@ -30,6 +33,10 @@ import ( // // Currently, only TCP tracking is supported. +// Our hash table has 16K buckets. +// TODO(gvisor.dev/issue/170): These should be tunable. +const numBuckets = 1 << 14 + // Direction of the tuple. type direction int @@ -48,7 +55,12 @@ const ( // tuple holds a connection's identifying and manipulating data in one // direction. It is immutable. +// +// +stateify savable type tuple struct { + // tupleEntry is used to build an intrusive list of tuples. + tupleEntry + tupleID // conn is the connection tracking entry this tuple belongs to. @@ -61,6 +73,8 @@ type tuple struct { // tupleID uniquely identifies a connection in one direction. It currently // contains enough information to distinguish between any TCP or UDP // connection, and will need to be extended to support other protocols. +// +// +stateify savable type tupleID struct { srcAddr tcpip.Address srcPort uint16 @@ -83,6 +97,8 @@ func (ti tupleID) reply() tupleID { } // conn is a tracked connection. +// +// +stateify savable type conn struct { // original is the tuple in original direction. It is immutable. original tuple @@ -98,22 +114,67 @@ type conn struct { tcbHook Hook // mu protects tcb. - mu sync.Mutex + mu sync.Mutex `state:"nosave"` // tcb is TCB control block. It is used to keep track of states // of tcp connection and is protected by mu. tcb tcpconntrack.TCB + + // lastUsed is the last time the connection saw a relevant packet, and + // is updated by each packet on the connection. It is protected by mu. + lastUsed time.Time `state:".(unixTime)"` +} + +// timedOut returns whether the connection timed out based on its state. +func (cn *conn) timedOut(now time.Time) bool { + const establishedTimeout = 5 * 24 * time.Hour + const defaultTimeout = 120 * time.Second + cn.mu.Lock() + defer cn.mu.Unlock() + if cn.tcb.State() == tcpconntrack.ResultAlive { + // Use the same default as Linux, which doesn't delete + // established connections for 5(!) days. + return now.Sub(cn.lastUsed) > establishedTimeout + } + // Use the same default as Linux, which lets connections in most states + // other than established remain for <= 120 seconds. + return now.Sub(cn.lastUsed) > defaultTimeout } // ConnTrack tracks all connections created for NAT rules. Most users are // expected to only call handlePacket and createConnFor. +// +// ConnTrack keeps all connections in a slice of buckets, each of which holds a +// linked list of tuples. This gives us some desirable properties: +// - Each bucket has its own lock, lessening lock contention. +// - The slice is large enough that lists stay short (<10 elements on average). +// Thus traversal is fast. +// - During linked list traversal we reap expired connections. This amortizes +// the cost of reaping them and makes reapUnused faster. +// +// Locks are ordered by their location in the buckets slice. That is, a +// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j. +// +// +stateify savable type ConnTrack struct { - // mu protects conns. - mu sync.RWMutex + // seed is a one-time random value initialized at stack startup + // and is used in the calculation of hash keys for the list of buckets. + // It is immutable. + seed uint32 - // conns maintains a map of tuples needed for connection tracking for - // iptables NAT rules. It is protected by mu. - conns map[tupleID]tuple + // mu protects the buckets slice, but not buckets' contents. Only take + // the write lock if you are modifying the slice or saving for S/R. + mu sync.RWMutex `state:"nosave"` + + // buckets is protected by mu. + buckets []bucket +} + +// +stateify savable +type bucket struct { + // mu protects tuples. + mu sync.Mutex `state:"nosave"` + tuples tupleList } // packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid @@ -143,8 +204,9 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { // newConn creates new connection. func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { conn := conn{ - manip: manip, - tcbHook: hook, + manip: manip, + tcbHook: hook, + lastUsed: time.Now(), } conn.original = tuple{conn: &conn, tupleID: orig} conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} @@ -162,14 +224,28 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { return nil, dirOriginal } - ct.mu.Lock() - defer ct.mu.Unlock() - - tuple, ok := ct.conns[tid] - if !ok { - return nil, dirOriginal + bucket := ct.bucket(tid) + now := time.Now() + + ct.mu.RLock() + defer ct.mu.RUnlock() + ct.buckets[bucket].mu.Lock() + defer ct.buckets[bucket].mu.Unlock() + + // Iterate over the tuples in a bucket, cleaning up any unused + // connections we find. + for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() { + // Clean up any timed-out connections we happen to find. + if ct.reapTupleLocked(other, bucket, now) { + // The tuple expired. + continue + } + if tid == other.tupleID { + return other.conn, other.direction + } } - return tuple.conn, tuple.direction + + return nil, dirOriginal } // createConnFor creates a new conn for pkt. @@ -197,13 +273,31 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg } conn := newConn(tid, replyTID, manip, hook) - // Add the changed tuple to the map. - // TODO(gvisor.dev/issue/170): Need to support collisions using linked - // list. - ct.mu.Lock() - defer ct.mu.Unlock() - ct.conns[tid] = conn.original - ct.conns[replyTID] = conn.reply + // Lock the buckets in the correct order. + tupleBucket := ct.bucket(tid) + replyBucket := ct.bucket(replyTID) + ct.mu.RLock() + defer ct.mu.RUnlock() + if tupleBucket < replyBucket { + ct.buckets[tupleBucket].mu.Lock() + ct.buckets[replyBucket].mu.Lock() + } else if tupleBucket > replyBucket { + ct.buckets[replyBucket].mu.Lock() + ct.buckets[tupleBucket].mu.Lock() + } else { + // Both tuples are in the same bucket. + ct.buckets[tupleBucket].mu.Lock() + } + + // Add the tuple to the map. + ct.buckets[tupleBucket].tuples.PushFront(&conn.original) + ct.buckets[replyBucket].tuples.PushFront(&conn.reply) + + // Unlocking can happen in any order. + ct.buckets[tupleBucket].mu.Unlock() + if tupleBucket != replyBucket { + ct.buckets[replyBucket].mu.Unlock() + } return conn } @@ -297,35 +391,134 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou // other tcp states. conn.mu.Lock() defer conn.mu.Unlock() - var st tcpconntrack.Result - tcpHeader := header.TCP(pkt.TransportHeader) - if conn.tcb.IsEmpty() { + + // Mark the connection as having been used recently so it isn't reaped. + conn.lastUsed = time.Now() + // Update connection state. + if tcpHeader := header.TCP(pkt.TransportHeader); conn.tcb.IsEmpty() { conn.tcb.Init(tcpHeader) conn.tcbHook = hook + } else if hook == conn.tcbHook { + conn.tcb.UpdateStateOutbound(tcpHeader) } else { - switch hook { - case conn.tcbHook: - st = conn.tcb.UpdateStateOutbound(tcpHeader) - default: - st = conn.tcb.UpdateStateInbound(tcpHeader) - } + conn.tcb.UpdateStateInbound(tcpHeader) } +} + +// bucket gets the conntrack bucket for a tupleID. +func (ct *ConnTrack) bucket(id tupleID) int { + h := jenkins.Sum32(ct.seed) + h.Write([]byte(id.srcAddr)) + h.Write([]byte(id.dstAddr)) + shortBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(shortBuf, id.srcPort) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, id.dstPort) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto)) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto)) + h.Write([]byte(shortBuf)) + ct.mu.RLock() + defer ct.mu.RUnlock() + return int(h.Sum32()) % len(ct.buckets) +} - // Delete conn if tcp connection is closed. - if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset { - ct.deleteConn(conn) +// reapUnused deletes timed out entries from the conntrack map. The rules for +// reaping are: +// - Most reaping occurs in connFor, which is called on each packet. connFor +// cleans up the bucket the packet's connection maps to. Thus calls to +// reapUnused should be fast. +// - Each call to reapUnused traverses a fraction of the conntrack table. +// Specifically, it traverses len(ct.buckets)/fractionPerReaping. +// - After reaping, reapUnused decides when it should next run based on the +// ratio of expired connections to examined connections. If the ratio is +// greater than maxExpiredPct, it schedules the next run quickly. Otherwise it +// slightly increases the interval between runs. +// - maxFullTraversal caps the time it takes to traverse the entire table. +// +// reapUnused returns the next bucket that should be checked and the time after +// which it should be called again. +func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) { + // TODO(gvisor.dev/issue/170): This can be more finely controlled, as + // it is in Linux via sysctl. + const fractionPerReaping = 128 + const maxExpiredPct = 50 + const maxFullTraversal = 60 * time.Second + const minInterval = 10 * time.Millisecond + const maxInterval = maxFullTraversal / fractionPerReaping + + now := time.Now() + checked := 0 + expired := 0 + var idx int + ct.mu.RLock() + defer ct.mu.RUnlock() + for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ { + idx = (i + start) % len(ct.buckets) + ct.buckets[idx].mu.Lock() + for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() { + checked++ + if ct.reapTupleLocked(tuple, idx, now) { + expired++ + } + } + ct.buckets[idx].mu.Unlock() + } + // We already checked buckets[idx]. + idx++ + + // If half or more of the connections are expired, the table has gotten + // stale. Reschedule quickly. + expiredPct := 0 + if checked != 0 { + expiredPct = expired * 100 / checked + } + if expiredPct > maxExpiredPct { + return idx, minInterval + } + if interval := prevInterval + minInterval; interval <= maxInterval { + // Increment the interval between runs. + return idx, interval } + // We've hit the maximum interval. + return idx, maxInterval } -// deleteConn deletes the connection. -func (ct *ConnTrack) deleteConn(conn *conn) { - if conn == nil { - return +// reapTupleLocked tries to remove tuple and its reply from the table. It +// returns whether the tuple's connection has timed out. +// +// Preconditions: ct.mu is locked for reading and bucket is locked. +func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool { + if !tuple.conn.timedOut(now) { + return false } - ct.mu.Lock() - defer ct.mu.Unlock() + // To maintain lock order, we can only reap these tuples if the reply + // appears later in the table. + replyBucket := ct.bucket(tuple.reply()) + if bucket > replyBucket { + return true + } + + // Don't re-lock if both tuples are in the same bucket. + differentBuckets := bucket != replyBucket + if differentBuckets { + ct.buckets[replyBucket].mu.Lock() + } + + // We have the buckets locked and can remove both tuples. + if tuple.direction == dirOriginal { + ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply) + } else { + ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original) + } + ct.buckets[bucket].tuples.Remove(tuple) + + // Don't re-unlock if both tuples are in the same bucket. + if differentBuckets { + ct.buckets[replyBucket].mu.Unlock() + } - delete(ct.conns, conn.original.tupleID) - delete(ct.conns, conn.reply.tupleID) + return true } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 974d77c36..f846ea2e5 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -16,6 +16,7 @@ package stack import ( "fmt" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -41,6 +42,9 @@ const ( // underflow. const HookUnset = -1 +// reaperDelay is how long to wait before starting to reap connections. +const reaperDelay = 5 * time.Second + // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. func DefaultTables() *IPTables { @@ -112,8 +116,9 @@ func DefaultTables() *IPTables { Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, }, connections: ConnTrack{ - conns: make(map[tupleID]tuple), + seed: generateRandUint32(), }, + reaperDone: make(chan struct{}, 1), } } @@ -169,6 +174,12 @@ func (it *IPTables) GetTable(name string) (Table, bool) { func (it *IPTables) ReplaceTable(name string, table Table) { it.mu.Lock() defer it.mu.Unlock() + // If iptables is being enabled, initialize the conntrack table and + // reaper. + if !it.modified { + it.connections.buckets = make([]bucket, numBuckets) + it.startReaper(reaperDelay) + } it.modified = true it.tables[name] = table } @@ -249,6 +260,35 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr return true } +// beforeSave is invoked by stateify. +func (it *IPTables) beforeSave() { + // Ensure the reaper exits cleanly. + it.reaperDone <- struct{}{} + // Prevent others from modifying the connection table. + it.connections.mu.Lock() +} + +// afterLoad is invoked by stateify. +func (it *IPTables) afterLoad() { + it.startReaper(reaperDelay) +} + +// startReaper starts a goroutine that wakes up periodically to reap timed out +// connections. +func (it *IPTables) startReaper(interval time.Duration) { + go func() { // S/R-SAFE: reaperDone is signalled when iptables is saved. + bucket := 0 + for { + select { + case <-it.reaperDone: + return + case <-time.After(interval): + bucket, interval = it.connections.reapUnused(bucket, interval) + } + } + }() +} + // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go new file mode 100644 index 000000000..529e02a07 --- /dev/null +++ b/pkg/tcpip/stack/iptables_state.go @@ -0,0 +1,40 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "time" +) + +// +stateify savable +type unixTime struct { + second int64 + nano int64 +} + +// saveLastUsed is invoked by stateify. +func (cn *conn) saveLastUsed() unixTime { + return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()} +} + +// loadLastUsed is invoked by stateify. +func (cn *conn) loadLastUsed(unix unixTime) { + cn.lastUsed = time.Unix(unix.second, unix.nano) +} + +// beforeSave is invoked by stateify. +func (ct *ConnTrack) beforeSave() { + ct.mu.Lock() +} diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index c528ec381..eb70e3104 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -78,6 +78,8 @@ const ( ) // IPTables holds all the tables for a netstack. +// +// +stateify savable type IPTables struct { // mu protects tables, priorities, and modified. mu sync.RWMutex @@ -97,10 +99,15 @@ type IPTables struct { modified bool connections ConnTrack + + // reaperDone can be signalled to stop the reaper goroutine. + reaperDone chan struct{} } // A Table defines a set of chains and hooks into the network stack. It is // really just a list of rules. +// +// +stateify savable type Table struct { // Rules holds the rules that make up the table. Rules []Rule @@ -130,6 +137,8 @@ func (table *Table) ValidHooks() uint32 { // contains zero or more matchers, each of which is a specification of which // packets this rule applies to. If there are no matchers in the rule, it // applies to any packet. +// +// +stateify savable type Rule struct { // Filter holds basic IP filtering fields common to every rule. Filter IPHeaderFilter @@ -142,6 +151,8 @@ type Rule struct { } // IPHeaderFilter holds basic IP filtering data common to every rule. +// +// +stateify savable type IPHeaderFilter struct { // Protocol matches the transport protocol. Protocol tcpip.TransportProtocolNumber diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index cdcfb8321..0aa815447 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -425,6 +425,7 @@ type Stack struct { handleLocal bool // tables are the iptables packet filtering and manipulation rules. + // TODO(gvisor.dev/issue/170): S/R this field. tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go index 12bc1b5b5..558b06df0 100644 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go @@ -106,6 +106,11 @@ func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result { return st } +// State returns the current state of the TCB. +func (t *TCB) State() Result { + return t.state +} + // IsAlive returns true as long as the connection is established(Alive) // or connecting state. func (t *TCB) IsAlive() bool { -- cgit v1.2.3 From e92f38ff0cd2e490637df2081fc8f75ddaf32937 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Wed, 15 Jul 2020 16:33:46 -0700 Subject: iptables: remove check for NetworkHeader This is no longer necessary, as we always set NetworkHeader before calling iptables.Check. PiperOrigin-RevId: 321461978 --- pkg/tcpip/stack/iptables.go | 30 +++++++++--------------------- 1 file changed, 9 insertions(+), 21 deletions(-) (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index f846ea2e5..bbf3b60e8 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -292,10 +292,9 @@ func (it *IPTables) startReaper(interval time.Duration) { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. @@ -319,9 +318,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * return drop, natPkts } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -366,23 +365,12 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId return chainDrop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] - // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data. - if pkt.NetworkHeader == nil { - var ok bool - pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // Precondition has been violated. - panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) - } - } - // Check whether the packet matches the IP header filter. if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) { // Continue on to the next rule. -- cgit v1.2.3 From bd98f820141208d9f19b0e12dee93f6f6de3ac97 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Wed, 22 Jul 2020 16:22:06 -0700 Subject: iptables: replace maps with arrays For iptables users, Check() is a hot path called for every packet one or more times. Let's avoid a bunch of map lookups. PiperOrigin-RevId: 322678699 --- pkg/sentry/socket/netfilter/netfilter.go | 21 ++-- pkg/tcpip/stack/iptables.go | 187 ++++++++++++++++++------------- pkg/tcpip/stack/iptables_types.go | 22 ++-- 3 files changed, 127 insertions(+), 103 deletions(-) (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index f7abe77d3..1243143ea 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -342,10 +342,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // TODO(gvisor.dev/issue/170): Support other tables. var table stack.Table switch replace.Name.String() { - case stack.TablenameFilter: + case stack.FilterTable: table = stack.EmptyFilterTable() - case stack.TablenameNat: - table = stack.EmptyNatTable() + case stack.NATTable: + table = stack.EmptyNATTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) return syserr.ErrInvalidArgument @@ -431,6 +431,8 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { for hook, _ := range replace.HookEntry { if table.ValidHooks()&(1< Date: Fri, 5 Jun 2020 11:22:44 -0700 Subject: iptables: don't NAT existing connections Fixes a NAT bug that manifested as: - A SYN was sent from gVisor to another host, unaffected by iptables. - The corresponding SYN/ACK was NATted by a PREROUTING REDIRECT rule despite being part of the existing connection. - The socket that sent the SYN never received the SYN/ACK and thus a connection could not be established. We handle this (as Linux does) by tracking all connections, inserting a no-op conntrack rule for new connections with no rules of their own. Needed for istio support (#170). --- pkg/tcpip/stack/conntrack.go | 136 +++++++++++++++++++++++++++++------- pkg/tcpip/stack/iptables.go | 21 +++++- pkg/tcpip/stack/iptables_targets.go | 2 +- test/iptables/iptables_test.go | 7 ++ test/iptables/nat.go | 52 ++++++++++++++ 5 files changed, 189 insertions(+), 29 deletions(-) (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index d39baf620..559a1c4dd 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -49,7 +49,8 @@ const ( type manipType int const ( - manipDstPrerouting manipType = iota + manipNone manipType = iota + manipDstPrerouting manipDstOutput ) @@ -113,13 +114,11 @@ type conn struct { // update the state of tcb. It is immutable. tcbHook Hook - // mu protects tcb. + // mu protects all mutable state. mu sync.Mutex `state:"nosave"` - // tcb is TCB control block. It is used to keep track of states // of tcp connection and is protected by mu. tcb tcpconntrack.TCB - // lastUsed is the last time the connection saw a relevant packet, and // is updated by each packet on the connection. It is protected by mu. lastUsed time.Time `state:".(unixTime)"` @@ -141,8 +140,26 @@ func (cn *conn) timedOut(now time.Time) bool { return now.Sub(cn.lastUsed) > defaultTimeout } +// update the connection tracking state. +// +// Precondition: ct.mu must be held. +func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) { + // Update the state of tcb. tcb assumes it's always initialized on the + // client. However, we only need to know whether the connection is + // established or not, so the client/server distinction isn't important. + // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle + // other tcp states. + if ct.tcb.IsEmpty() { + ct.tcb.Init(tcpHeader) + } else if hook == ct.tcbHook { + ct.tcb.UpdateStateOutbound(tcpHeader) + } else { + ct.tcb.UpdateStateInbound(tcpHeader) + } +} + // ConnTrack tracks all connections created for NAT rules. Most users are -// expected to only call handlePacket and createConnFor. +// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop. // // ConnTrack keeps all connections in a slice of buckets, each of which holds a // linked list of tuples. This gives us some desirable properties: @@ -248,8 +265,7 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { return nil, dirOriginal } -// createConnFor creates a new conn for pkt. -func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { tid, err := packetToTupleID(pkt) if err != nil { return nil @@ -272,10 +288,15 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg manip = manipDstOutput } conn := newConn(tid, replyTID, manip, hook) + ct.insertConn(conn) + return conn +} +// insertConn inserts conn into the appropriate table bucket. +func (ct *ConnTrack) insertConn(conn *conn) { // Lock the buckets in the correct order. - tupleBucket := ct.bucket(tid) - replyBucket := ct.bucket(replyTID) + tupleBucket := ct.bucket(conn.original.tupleID) + replyBucket := ct.bucket(conn.reply.tupleID) ct.mu.RLock() defer ct.mu.RUnlock() if tupleBucket < replyBucket { @@ -289,22 +310,37 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg ct.buckets[tupleBucket].mu.Lock() } - // Add the tuple to the map. - ct.buckets[tupleBucket].tuples.PushFront(&conn.original) - ct.buckets[replyBucket].tuples.PushFront(&conn.reply) + // Now that we hold the locks, ensure the tuple hasn't been inserted by + // another thread. + alreadyInserted := false + for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { + if other.tupleID == conn.original.tupleID { + alreadyInserted = true + break + } + } + + if !alreadyInserted { + // Add the tuple to the map. + ct.buckets[tupleBucket].tuples.PushFront(&conn.original) + ct.buckets[replyBucket].tuples.PushFront(&conn.reply) + } // Unlocking can happen in any order. ct.buckets[tupleBucket].mu.Unlock() if tupleBucket != replyBucket { ct.buckets[replyBucket].mu.Unlock() } - - return conn } // handlePacketPrerouting manipulates ports for packets in Prerouting hook. // TODO(gvisor.dev/issue/170): Change address for Prerouting hook. func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -322,12 +358,22 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { netHeader.SetSourceAddress(conn.original.dstAddr) } + // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated + // on inbound packets, so we don't recalculate them. However, we should + // support cases when they are validated, e.g. when we can't offload + // receive checksumming. + netHeader.SetChecksum(0) netHeader.SetChecksum(^netHeader.CalculateChecksum()) } // handlePacketOutput manipulates ports for packets in Output hook. func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) { + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -362,20 +408,31 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d } // handlePacket will manipulate the port and address of the packet if the -// connection exists. -func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) { +// connection exists. Returns whether, after the packet traverses the tables, +// it should create a new entry in the table. +func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool { if pkt.NatDone { - return + return false } if hook != Prerouting && hook != Output { - return + return false + } + + // TODO(gvisor.dev/issue/170): Support other transport protocols. + if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + return false } conn, dir := ct.connFor(pkt) + // Connection or Rule not found for the packet. if conn == nil { - // Connection not found for the packet or the packet is invalid. - return + return true + } + + tcpHeader := header.TCP(pkt.TransportHeader) + if tcpHeader == nil { + return false } switch hook { @@ -395,14 +452,39 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou // Mark the connection as having been used recently so it isn't reaped. conn.lastUsed = time.Now() // Update connection state. - if tcpHeader := header.TCP(pkt.TransportHeader); conn.tcb.IsEmpty() { - conn.tcb.Init(tcpHeader) - conn.tcbHook = hook - } else if hook == conn.tcbHook { - conn.tcb.UpdateStateOutbound(tcpHeader) - } else { - conn.tcb.UpdateStateInbound(tcpHeader) + conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + + return false +} + +// maybeInsertNoop tries to insert a no-op connection entry to keep connections +// from getting clobbered when replies arrive. It only inserts if there isn't +// already a connection for pkt. +// +// This should be called after traversing iptables rules only, to ensure that +// pkt.NatDone is set correctly. +func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { + // If there were a rule applying to this packet, it would be marked + // with NatDone. + if pkt.NatDone { + return + } + + // We only track TCP connections. + if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + return + } + + // This is the first packet we're seeing for the TCP connection. Insert + // the noop entry (an identity mapping) so that the response doesn't + // get NATed, breaking the connection. + tid, err := packetToTupleID(pkt) + if err != nil { + return } + conn := newConn(tid, tid.reply(), manipNone, hook) + conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + ct.insertConn(conn) } // bucket gets the conntrack bucket for a tupleID. diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 5f647c5fe..ca1dda695 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -245,13 +245,18 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // Packets are manipulated only if connection and matching // NAT rule exists. - it.connections.handlePacket(pkt, hook, gso, r) + shouldTrack := it.connections.handlePacket(pkt, hook, gso, r) // Go through each table containing the hook. it.mu.RLock() defer it.mu.RUnlock() priorities := it.priorities[hook] for _, tableID := range priorities { + // If handlePacket already NATed the packet, we don't need to + // check the NAT table. + if tableID == natID && pkt.NatDone { + continue + } table := it.tables[tableID] ruleIdx := table.BuiltinChains[hook] switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { @@ -281,6 +286,20 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr } } + // If this connection should be tracked, try to add an entry for it. If + // traversing the nat table didn't end in adding an entry, + // maybeInsertNoop will add a no-op entry for the connection. This is + // needeed when establishing connections so that the SYN/ACK reply to an + // outgoing SYN is delivered to the correct endpoint rather than being + // redirected by a prerouting rule. + // + // From the iptables documentation: "If there is no rule, a `null' + // binding is created: this usually does not map the packet, but exists + // to ensure we don't map another stream over an existing one." + if shouldTrack { + it.connections.maybeInsertNoop(pkt, hook) + } + // Every table returned Accept. return true } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index d43f60c67..dc88033c7 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -153,7 +153,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso // Set up conection for matching NAT rule. Only the first // packet of the connection comes here. Other packets will be // manipulated in connection tracking. - if conn := ct.createConnFor(pkt, hook, rt); conn != nil { + if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil { ct.handlePacket(pkt, hook, gso, r) } default: diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index f5ac79370..f303030aa 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -263,6 +263,13 @@ func TestNATPreRedirectTCPPort(t *testing.T) { singleTest(t, NATPreRedirectTCPPort{}) } +func TestNATPreRedirectTCPOutgoing(t *testing.T) { + singleTest(t, NATPreRedirectTCPOutgoing{}) +} + +func TestNATOutRedirectTCPIncoming(t *testing.T) { + singleTest(t, NATOutRedirectTCPIncoming{}) +} func TestNATOutRedirectUDPPort(t *testing.T) { singleTest(t, NATOutRedirectUDPPort{}) } diff --git a/test/iptables/nat.go b/test/iptables/nat.go index 8562b0820..149dec2bb 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -28,6 +28,8 @@ const ( func init() { RegisterTestCase(NATPreRedirectUDPPort{}) RegisterTestCase(NATPreRedirectTCPPort{}) + RegisterTestCase(NATPreRedirectTCPOutgoing{}) + RegisterTestCase(NATOutRedirectTCPIncoming{}) RegisterTestCase(NATOutRedirectUDPPort{}) RegisterTestCase(NATOutRedirectTCPPort{}) RegisterTestCase(NATDropUDP{}) @@ -91,6 +93,56 @@ func (NATPreRedirectTCPPort) LocalAction(ip net.IP) error { return connectTCP(ip, dropPort, sendloopDuration) } +// NATPreRedirectTCPOutgoing verifies that outgoing TCP connections aren't +// affected by PREROUTING connection tracking. +type NATPreRedirectTCPOutgoing struct{} + +// Name implements TestCase.Name. +func (NATPreRedirectTCPOutgoing) Name() string { + return "NATPreRedirectTCPOutgoing" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreRedirectTCPOutgoing) ContainerAction(ip net.IP) error { + // Redirect all incoming TCP traffic to a closed port. + if err := natTable("-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { + return err + } + + // Establish a connection to the host process. + return connectTCP(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreRedirectTCPOutgoing) LocalAction(ip net.IP) error { + return listenTCP(acceptPort, sendloopDuration) +} + +// NATOutRedirectTCPIncoming verifies that incoming TCP connections aren't +// affected by OUTPUT connection tracking. +type NATOutRedirectTCPIncoming struct{} + +// Name implements TestCase.Name. +func (NATOutRedirectTCPIncoming) Name() string { + return "NATOutRedirectTCPIncoming" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutRedirectTCPIncoming) ContainerAction(ip net.IP) error { + // Redirect all outgoing TCP traffic to a closed port. + if err := natTable("-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { + return err + } + + // Establish a connection to the host process. + return listenTCP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutRedirectTCPIncoming) LocalAction(ip net.IP) error { + return connectTCP(ip, acceptPort, sendloopDuration) +} + // NATOutRedirectUDPPort tests that packets are redirected to different port. type NATOutRedirectUDPPort struct{} -- cgit v1.2.3 From dd530eeeff09128d4c2428e1d6f24205a29e661e Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 23 Jul 2020 15:44:09 -0700 Subject: iptables: use keyed array literals PiperOrigin-RevId: 322882426 --- pkg/tcpip/stack/iptables.go | 93 ++++++++++++++++++--------------------------- 1 file changed, 37 insertions(+), 56 deletions(-) (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index ca1dda695..cbbae4224 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -58,8 +58,7 @@ const reaperDelay = 5 * time.Second func DefaultTables() *IPTables { return &IPTables{ tables: [numTables]Table{ - // NAT table. - Table{ + natID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, @@ -68,22 +67,21 @@ func DefaultTables() *IPTables { Rule{Target: ErrorTarget{}}, }, BuiltinChains: [NumHooks]int{ - 0, // Prerouting. - 1, // Input. - HookUnset, // Forward. - 2, // Output. - 3, // Postrouting. + Prerouting: 0, + Input: 1, + Forward: HookUnset, + Output: 2, + Postrouting: 3, }, Underflows: [NumHooks]int{ - 0, // Prerouting. - 1, // Input. - HookUnset, // Forward. - 2, // Output. - 3, // Postrouting. + Prerouting: 0, + Input: 1, + Forward: HookUnset, + Output: 2, + Postrouting: 3, }, }, - // Mangle table. - Table{ + mangleID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, @@ -94,15 +92,14 @@ func DefaultTables() *IPTables { Output: 1, }, Underflows: [NumHooks]int{ - 0, // Prerouting. - HookUnset, // Input. - HookUnset, // Forward. - 1, // Output. - HookUnset, // Postrouting. + Prerouting: 0, + Input: HookUnset, + Forward: HookUnset, + Output: 1, + Postrouting: HookUnset, }, }, - // Filter table. - Table{ + filterID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, @@ -110,27 +107,25 @@ func DefaultTables() *IPTables { Rule{Target: ErrorTarget{}}, }, BuiltinChains: [NumHooks]int{ - HookUnset, // Prerouting. - Input: 0, // Input. - Forward: 1, // Forward. - Output: 2, // Output. - HookUnset, // Postrouting. + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, }, Underflows: [NumHooks]int{ - HookUnset, // Prerouting. - 0, // Input. - 1, // Forward. - 2, // Output. - HookUnset, // Postrouting. + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, }, }, }, priorities: [NumHooks][]tableID{ - []tableID{mangleID, natID}, // Prerouting. - []tableID{natID, filterID}, // Input. - []tableID{}, // Forward. - []tableID{mangleID, natID, filterID}, // Output. - []tableID{}, // Postrouting. + Prerouting: []tableID{mangleID, natID}, + Input: []tableID{natID, filterID}, + Output: []tableID{mangleID, natID, filterID}, }, connections: ConnTrack{ seed: generateRandUint32(), @@ -145,18 +140,12 @@ func EmptyFilterTable() Table { return Table{ Rules: []Rule{}, BuiltinChains: [NumHooks]int{ - HookUnset, - 0, - 0, - 0, - HookUnset, + Prerouting: HookUnset, + Postrouting: HookUnset, }, Underflows: [NumHooks]int{ - HookUnset, - 0, - 0, - 0, - HookUnset, + Prerouting: HookUnset, + Postrouting: HookUnset, }, } } @@ -167,18 +156,10 @@ func EmptyNATTable() Table { return Table{ Rules: []Rule{}, BuiltinChains: [NumHooks]int{ - 0, - 0, - HookUnset, - 0, - 0, + Forward: HookUnset, }, Underflows: [NumHooks]int{ - 0, - 0, - HookUnset, - 0, - 0, + Forward: HookUnset, }, } } -- cgit v1.2.3 From 2a7b2a61e3ea32129c26eeaa6fab3d81a5d8ad6e Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 11 Jun 2020 20:33:56 -0700 Subject: iptables: support SO_ORIGINAL_DST Envoy (#170) uses this to get the original destination of redirected packets. --- pkg/abi/linux/netfilter.go | 8 +- pkg/abi/linux/socket.go | 4 +- pkg/sentry/socket/netstack/netstack.go | 17 ++++ pkg/sentry/strace/socket.go | 1 + pkg/tcpip/stack/conntrack.go | 26 ++++++ pkg/tcpip/stack/iptables.go | 11 ++- pkg/tcpip/tcpip.go | 4 + pkg/tcpip/transport/tcp/endpoint.go | 11 +++ test/iptables/BUILD | 1 + test/iptables/iptables_test.go | 8 ++ test/iptables/iptables_unsafe.go | 63 ++++++++++++++ test/iptables/iptables_util.go | 51 ++++++++++- test/iptables/nat.go | 152 ++++++++++++++++++++++++++++++++- 13 files changed, 344 insertions(+), 13 deletions(-) create mode 100644 test/iptables/iptables_unsafe.go (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index a91f9f018..9c27f7bb2 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -59,7 +59,7 @@ var VerdictStrings = map[int32]string{ NF_RETURN: "RETURN", } -// Socket options. These correspond to values in +// Socket options for SOL_SOCKET. These correspond to values in // include/uapi/linux/netfilter_ipv4/ip_tables.h. const ( IPT_BASE_CTL = 64 @@ -74,6 +74,12 @@ const ( IPT_SO_GET_MAX = IPT_SO_GET_REVISION_TARGET ) +// Socket option for SOL_IP. This corresponds to the value in +// include/uapi/linux/netfilter_ipv4.h. +const ( + SO_ORIGINAL_DST = 80 +) + // Name lengths. These correspond to values in // include/uapi/linux/netfilter/x_tables.h. const ( diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index c24a8216e..d6946bb82 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -239,11 +239,13 @@ const SockAddrMax = 128 type InetAddr [4]byte // SockAddrInet is struct sockaddr_in, from uapi/linux/in.h. +// +// +marshal type SockAddrInet struct { Family uint16 Port uint16 Addr InetAddr - Zero [8]uint8 // pad to sizeof(struct sockaddr). + _ [8]uint8 // pad to sizeof(struct sockaddr). } // InetMulticastRequest is struct ip_mreq, from uapi/linux/in.h. diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index f86e6cd7a..31a168f7e 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1490,6 +1490,10 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marsha vP := primitive.Int32(boolToInt32(v)) return &vP, nil + case linux.SO_ORIGINAL_DST: + // TODO(gvisor.dev/issue/170): ip6tables. + return nil, syserr.ErrInvalidArgument + default: emitUnimplementedEventIPv6(t, name) } @@ -1600,6 +1604,19 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in vP := primitive.Int32(boolToInt32(v)) return &vP, nil + case linux.SO_ORIGINAL_DST: + if outLen < int(binary.Size(linux.SockAddrInet{})) { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.OriginalDestinationOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) + return a.(*linux.SockAddrInet), nil + default: emitUnimplementedEventIP(t, name) } diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index c0512de89..b51c4c941 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -521,6 +521,7 @@ var sockOptNames = map[uint64]abi.ValueSet{ linux.IP_ROUTER_ALERT: "IP_ROUTER_ALERT", linux.IP_PKTOPTIONS: "IP_PKTOPTIONS", linux.IP_MTU: "IP_MTU", + linux.SO_ORIGINAL_DST: "SO_ORIGINAL_DST", }, linux.SOL_SOCKET: { linux.SO_ERROR: "SO_ERROR", diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 559a1c4dd..470c265aa 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -240,7 +240,10 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { if err != nil { return nil, dirOriginal } + return ct.connForTID(tid) +} +func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { bucket := ct.bucket(tid) now := time.Now() @@ -604,3 +607,26 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo return true } + +func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { + // Lookup the connection. The reply's original destination + // describes the original address. + tid := tupleID{ + srcAddr: epID.LocalAddress, + srcPort: epID.LocalPort, + dstAddr: epID.RemoteAddress, + dstPort: epID.RemotePort, + transProto: header.TCPProtocolNumber, + netProto: header.IPv4ProtocolNumber, + } + conn, _ := ct.connForTID(tid) + if conn == nil { + // Not a tracked connection. + return "", 0, tcpip.ErrNotConnected + } else if conn.manip == manipNone { + // Unmanipulated connection. + return "", 0, tcpip.ErrInvalidOptionValue + } + + return conn.original.dstAddr, conn.original.dstPort, nil +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index cbbae4224..110ba073d 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -218,19 +218,16 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // Many users never configure iptables. Spare them the cost of rule // traversal if rules have never been set. it.mu.RLock() + defer it.mu.RUnlock() if !it.modified { - it.mu.RUnlock() return true } - it.mu.RUnlock() // Packets are manipulated only if connection and matching // NAT rule exists. shouldTrack := it.connections.handlePacket(pkt, hook, gso, r) // Go through each table containing the hook. - it.mu.RLock() - defer it.mu.RUnlock() priorities := it.priorities[hook] for _, tableID := range priorities { // If handlePacket already NATed the packet, we don't need to @@ -418,3 +415,9 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // All the matchers matched, so run the target. return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) } + +// OriginalDst returns the original destination of redirected connections. It +// returns an error if the connection doesn't exist or isn't redirected. +func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { + return it.connections.originalDst(epID) +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index a634b9b60..45f59b60f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -954,6 +954,10 @@ type DefaultTTLOption uint8 // classic BPF filter on a given endpoint. type SocketDetachFilterOption int +// OriginalDestinationOption is used to get the original destination address +// and port of a redirected packet. +type OriginalDestinationOption FullAddress + // IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 0f7487963..682687ebe 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2017,6 +2017,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = tcpip.TCPDeferAcceptOption(e.deferAccept) e.UnlockUser() + case *tcpip.OriginalDestinationOption: + ipt := e.stack.IPTables() + addr, port, err := ipt.OriginalDst(e.ID) + if err != nil { + return err + } + *o = tcpip.OriginalDestinationOption{ + Addr: addr, + Port: port, + } + default: return tcpip.ErrUnknownProtocolOption } diff --git a/test/iptables/BUILD b/test/iptables/BUILD index 40b63ebbe..66453772a 100644 --- a/test/iptables/BUILD +++ b/test/iptables/BUILD @@ -9,6 +9,7 @@ go_library( "filter_input.go", "filter_output.go", "iptables.go", + "iptables_unsafe.go", "iptables_util.go", "nat.go", ], diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 550b6198a..fda5f694f 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -371,3 +371,11 @@ func TestFilterAddrs(t *testing.T) { } } } + +func TestNATPreOriginalDst(t *testing.T) { + singleTest(t, NATPreOriginalDst{}) +} + +func TestNATOutOriginalDst(t *testing.T) { + singleTest(t, NATOutOriginalDst{}) +} diff --git a/test/iptables/iptables_unsafe.go b/test/iptables/iptables_unsafe.go new file mode 100644 index 000000000..bd85a8fea --- /dev/null +++ b/test/iptables/iptables_unsafe.go @@ -0,0 +1,63 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iptables + +import ( + "fmt" + "syscall" + "unsafe" +) + +type originalDstError struct { + errno syscall.Errno +} + +func (e originalDstError) Error() string { + return fmt.Sprintf("errno (%d) when calling getsockopt(SO_ORIGINAL_DST): %v", int(e.errno), e.errno.Error()) +} + +// SO_ORIGINAL_DST gets the original destination of a redirected packet via +// getsockopt. +const SO_ORIGINAL_DST = 80 + +func originalDestination4(connfd int) (syscall.RawSockaddrInet4, error) { + var addr syscall.RawSockaddrInet4 + var addrLen uint32 = syscall.SizeofSockaddrInet4 + if errno := originalDestination(connfd, syscall.SOL_IP, unsafe.Pointer(&addr), &addrLen); errno != 0 { + return syscall.RawSockaddrInet4{}, originalDstError{errno} + } + return addr, nil +} + +func originalDestination6(connfd int) (syscall.RawSockaddrInet6, error) { + var addr syscall.RawSockaddrInet6 + var addrLen uint32 = syscall.SizeofSockaddrInet6 + if errno := originalDestination(connfd, syscall.SOL_IPV6, unsafe.Pointer(&addr), &addrLen); errno != 0 { + return syscall.RawSockaddrInet6{}, originalDstError{errno} + } + return addr, nil +} + +func originalDestination(connfd int, level uintptr, optval unsafe.Pointer, optlen *uint32) syscall.Errno { + _, _, errno := syscall.Syscall6( + syscall.SYS_GETSOCKOPT, + uintptr(connfd), + level, + SO_ORIGINAL_DST, + uintptr(optval), + uintptr(unsafe.Pointer(optlen)), + 0) + return errno +} diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index ca80a4b5f..5125fe47b 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -15,6 +15,8 @@ package iptables import ( + "encoding/binary" + "errors" "fmt" "net" "os/exec" @@ -218,17 +220,58 @@ func filterAddrs(addrs []string, ipv6 bool) []string { // getInterfaceName returns the name of the interface other than loopback. func getInterfaceName() (string, bool) { - var ifname string + iface, ok := getNonLoopbackInterface() + if !ok { + return "", false + } + return iface.Name, true +} + +func getInterfaceAddrs(ipv6 bool) ([]net.IP, error) { + iface, ok := getNonLoopbackInterface() + if !ok { + return nil, errors.New("no non-loopback interface found") + } + addrs, err := iface.Addrs() + if err != nil { + return nil, err + } + + // Get only IPv4 or IPv6 addresses. + ips := make([]net.IP, 0, len(addrs)) + for _, addr := range addrs { + parts := strings.Split(addr.String(), "/") + var ip net.IP + // To16() returns IPv4 addresses as IPv4-mapped IPv6 addresses. + // So we check whether To4() returns nil to test whether the + // address is v4 or v6. + if v4 := net.ParseIP(parts[0]).To4(); ipv6 && v4 == nil { + ip = net.ParseIP(parts[0]).To16() + } else { + ip = v4 + } + if ip != nil { + ips = append(ips, ip) + } + } + return ips, nil +} + +func getNonLoopbackInterface() (net.Interface, bool) { if interfaces, err := net.Interfaces(); err == nil { for _, intf := range interfaces { if intf.Name != "lo" { - ifname = intf.Name - break + return intf, true } } } + return net.Interface{}, false +} - return ifname, ifname != "" +func htons(x uint16) uint16 { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, x) + return binary.LittleEndian.Uint16(buf) } func localIP(ipv6 bool) string { diff --git a/test/iptables/nat.go b/test/iptables/nat.go index ac0d91bb2..b7fea2527 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -18,12 +18,11 @@ import ( "errors" "fmt" "net" + "syscall" "time" ) -const ( - redirectPort = 42 -) +const redirectPort = 42 func init() { RegisterTestCase(NATPreRedirectUDPPort{}) @@ -42,6 +41,8 @@ func init() { RegisterTestCase(NATOutRedirectInvert{}) RegisterTestCase(NATRedirectRequiresProtocol{}) RegisterTestCase(NATLoopbackSkipsPrerouting{}) + RegisterTestCase(NATPreOriginalDst{}) + RegisterTestCase(NATOutOriginalDst{}) } // NATPreRedirectUDPPort tests that packets are redirected to different port. @@ -471,6 +472,151 @@ func (NATLoopbackSkipsPrerouting) LocalAction(ip net.IP, ipv6 bool) error { return nil } +// NATPreOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination +// of PREROUTING NATted packets. +type NATPreOriginalDst struct{} + +// Name implements TestCase.Name. +func (NATPreOriginalDst) Name() string { + return "NATPreOriginalDst" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error { + // Redirect incoming TCP connections to acceptPort. + if err := natTable(ipv6, "-A", "PREROUTING", + "-p", "tcp", + "--destination-port", fmt.Sprintf("%d", dropPort), + "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)); err != nil { + return err + } + + addrs, err := getInterfaceAddrs(ipv6) + if err != nil { + return err + } + return listenForRedirectedConn(ipv6, addrs) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreOriginalDst) LocalAction(ip net.IP, ipv6 bool) error { + return connectTCP(ip, dropPort, sendloopDuration) +} + +// NATOutOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination +// of OUTBOUND NATted packets. +type NATOutOriginalDst struct{} + +// Name implements TestCase.Name. +func (NATOutOriginalDst) Name() string { + return "NATOutOriginalDst" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error { + // Redirect incoming TCP connections to acceptPort. + if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)); err != nil { + return err + } + + connCh := make(chan error) + go func() { + connCh <- connectTCP(ip, dropPort, sendloopDuration) + }() + + if err := listenForRedirectedConn(ipv6, []net.IP{ip}); err != nil { + return err + } + return <-connCh +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutOriginalDst) LocalAction(ip net.IP, ipv6 bool) error { + // No-op. + return nil +} + +func listenForRedirectedConn(ipv6 bool, originalDsts []net.IP) error { + // The net package doesn't give guarantee access to the connection's + // underlying FD, and thus we cannot call getsockopt. We have to use + // traditional syscalls for SO_ORIGINAL_DST. + + // Create the listening socket, bind, listen, and accept. + family := syscall.AF_INET + if ipv6 { + family = syscall.AF_INET6 + } + sockfd, err := syscall.Socket(family, syscall.SOCK_STREAM, 0) + if err != nil { + return err + } + defer syscall.Close(sockfd) + + var bindAddr syscall.Sockaddr + if ipv6 { + bindAddr = &syscall.SockaddrInet6{ + Port: acceptPort, + Addr: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // in6addr_any + } + } else { + bindAddr = &syscall.SockaddrInet4{ + Port: acceptPort, + Addr: [4]byte{0, 0, 0, 0}, // INADDR_ANY + } + } + if err := syscall.Bind(sockfd, bindAddr); err != nil { + return err + } + + if err := syscall.Listen(sockfd, 1); err != nil { + return err + } + + connfd, _, err := syscall.Accept(sockfd) + if err != nil { + return err + } + defer syscall.Close(connfd) + + // Verify that, despite listening on acceptPort, SO_ORIGINAL_DST + // indicates the packet was sent to originalDst:dropPort. + if ipv6 { + got, err := originalDestination6(connfd) + if err != nil { + return err + } + // The original destination could be any of our IPs. + for _, dst := range originalDsts { + want := syscall.RawSockaddrInet6{ + Family: syscall.AF_INET6, + Port: htons(dropPort), + } + copy(want.Addr[:], dst.To16()) + if got == want { + return nil + } + } + return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts) + } else { + got, err := originalDestination4(connfd) + if err != nil { + return err + } + // The original destination could be any of our IPs. + for _, dst := range originalDsts { + want := syscall.RawSockaddrInet4{ + Family: syscall.AF_INET, + Port: htons(dropPort), + } + copy(want.Addr[:], dst.To4()) + if got == want { + return nil + } + } + return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts) + } +} + // loopbackTests runs an iptables rule and ensures that packets sent to // dest:dropPort are received by localhost:acceptPort. func loopbackTest(ipv6 bool, dest net.IP, args ...string) error { -- cgit v1.2.3 From 47515f475167ffa23267ca0b9d1b39e7907587d6 Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Thu, 13 Aug 2020 13:07:03 -0700 Subject: Migrate to PacketHeader API for PacketBuffer. Formerly, when a packet is constructed or parsed, all headers are set by the client code. This almost always involved prepending to pk.Header buffer or trimming pk.Data portion. This is known to prone to bugs, due to the complexity and number of the invariants assumed across netstack to maintain. In the new PacketHeader API, client will call Push()/Consume() method to construct/parse an outgoing/incoming packet. All invariants, such as slicing and trimming, are maintained by the API itself. NewPacketBuffer() is introduced to create new PacketBuffer. Zero value is no longer valid. PacketBuffer now assumes the packet is a concatenation of following portions: * LinkHeader * NetworkHeader * TransportHeader * Data Any of them could be empty, or zero-length. PiperOrigin-RevId: 326507688 --- pkg/sentry/socket/netfilter/tcp_matcher.go | 4 +- pkg/sentry/socket/netfilter/udp_matcher.go | 4 +- pkg/tcpip/buffer/view.go | 10 + pkg/tcpip/link/channel/channel.go | 4 +- pkg/tcpip/link/fdbased/BUILD | 1 + pkg/tcpip/link/fdbased/endpoint.go | 14 +- pkg/tcpip/link/fdbased/endpoint_test.go | 134 ++++--- pkg/tcpip/link/fdbased/mmap.go | 16 +- pkg/tcpip/link/fdbased/packet_dispatchers.go | 45 ++- pkg/tcpip/link/loopback/loopback.go | 21 +- pkg/tcpip/link/muxed/injectable_test.go | 25 +- pkg/tcpip/link/nested/nested_test.go | 4 +- pkg/tcpip/link/sharedmem/sharedmem.go | 25 +- pkg/tcpip/link/sharedmem/sharedmem_test.go | 135 +++---- pkg/tcpip/link/sharedmem/tx.go | 14 +- pkg/tcpip/link/sniffer/sniffer.go | 19 +- pkg/tcpip/link/tun/device.go | 23 +- pkg/tcpip/link/waitable/waitable_test.go | 12 +- pkg/tcpip/network/arp/arp.go | 26 +- pkg/tcpip/network/arp/arp_test.go | 8 +- pkg/tcpip/network/ip_test.go | 80 +++-- pkg/tcpip/network/ipv4/icmp.go | 38 +- pkg/tcpip/network/ipv4/ipv4.go | 188 +++++----- pkg/tcpip/network/ipv4/ipv4_test.go | 72 ++-- pkg/tcpip/network/ipv6/icmp.go | 55 ++- pkg/tcpip/network/ipv6/icmp_test.go | 46 +-- pkg/tcpip/network/ipv6/ipv6.go | 47 +-- pkg/tcpip/network/ipv6/ipv6_test.go | 16 +- pkg/tcpip/network/ipv6/ndp_test.go | 29 +- pkg/tcpip/stack/BUILD | 2 + pkg/tcpip/stack/conntrack.go | 31 +- pkg/tcpip/stack/forwarder_test.go | 65 ++-- pkg/tcpip/stack/headertype_string.go | 39 ++ pkg/tcpip/stack/iptables.go | 2 +- pkg/tcpip/stack/iptables_targets.go | 9 +- pkg/tcpip/stack/ndp.go | 32 +- pkg/tcpip/stack/ndp_test.go | 22 +- pkg/tcpip/stack/nic.go | 40 +-- pkg/tcpip/stack/nic_test.go | 4 +- pkg/tcpip/stack/packet_buffer.go | 260 ++++++++++++-- pkg/tcpip/stack/packet_buffer_test.go | 397 +++++++++++++++++++++ pkg/tcpip/stack/route.go | 5 +- pkg/tcpip/stack/stack_test.go | 61 ++-- pkg/tcpip/stack/transport_demuxer_test.go | 14 +- pkg/tcpip/stack/transport_test.go | 54 ++- .../tests/integration/multicast_broadcast_test.go | 18 +- pkg/tcpip/transport/icmp/endpoint.go | 39 +- pkg/tcpip/transport/packet/endpoint.go | 35 +- pkg/tcpip/transport/raw/endpoint.go | 30 +- pkg/tcpip/transport/tcp/connect.go | 30 +- pkg/tcpip/transport/tcp/protocol.go | 17 +- pkg/tcpip/transport/tcp/segment.go | 2 +- pkg/tcpip/transport/tcp/testing/context/context.go | 28 +- pkg/tcpip/transport/udp/endpoint.go | 26 +- pkg/tcpip/transport/udp/protocol.go | 57 ++- pkg/tcpip/transport/udp/udp_test.go | 58 ++- 56 files changed, 1568 insertions(+), 924 deletions(-) create mode 100644 pkg/tcpip/stack/headertype_string.go create mode 100644 pkg/tcpip/stack/packet_buffer_test.go (limited to 'pkg/tcpip/stack/iptables.go') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 4f98ee2d5..0bfd6c1f4 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -97,7 +97,7 @@ func (*TCPMatcher) Name() string { // Match implements Matcher.Match. func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) if netHeader.TransportProtocol() != header.TCPProtocolNumber { return false, false @@ -111,7 +111,7 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceN return false, false } - tcpHeader := header.TCP(pkt.TransportHeader) + tcpHeader := header.TCP(pkt.TransportHeader().View()) if len(tcpHeader) < header.TCPMinimumSize { // There's no valid TCP header here, so we drop the packet immediately. return false, true diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 3f20fc891..7ed05461d 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -94,7 +94,7 @@ func (*UDPMatcher) Name() string { // Match implements Matcher.Match. func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved // into the stack.Check codepath as matchers are added. @@ -110,7 +110,7 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceN return false, false } - udpHeader := header.UDP(pkt.TransportHeader) + udpHeader := header.UDP(pkt.TransportHeader().View()) if len(udpHeader) < header.UDPMinimumSize { // There's no valid UDP header here, so we drop the packet immediately. return false, true diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 9a3c5d6c3..ea0c5413d 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -65,6 +65,16 @@ func (v View) ToVectorisedView() VectorisedView { return NewVectorisedView(len(v), []View{v}) } +// IsEmpty returns whether v is of length zero. +func (v View) IsEmpty() bool { + return len(v) == 0 +} + +// Size returns the length of v. +func (v View) Size() int { + return len(v) +} + // VectorisedView is a vectorised version of View using non contiguous memory. // It supports all the convenience methods supported by View. // diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index e12a5929b..c95aef63c 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -274,7 +274,9 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := PacketInfo{ - Pkt: &stack.PacketBuffer{Data: vv}, + Pkt: stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }), Proto: 0, GSO: nil, } diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 507b44abc..10072eac1 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -37,5 +37,6 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/rawfile", "//pkg/tcpip/stack", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index c18bb91fb..975309fc8 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -390,8 +390,7 @@ const ( func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if e.hdrSize > 0 { // Add ethernet header if needed. - eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) - pkt.LinkHeader = buffer.View(eth) + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) ethHdr := &header.EthernetFields{ DstAddr: remote, Type: protocol, @@ -420,7 +419,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { vnetHdr := virtioNetHdr{} if gso != nil { - vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) + vnetHdr.hdrLen = uint16(pkt.HeaderSize()) if gso.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen @@ -443,11 +442,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne builder.Add(vnetHdrBuf) } - builder.Add(pkt.Header.View()) - for _, v := range pkt.Data.Views() { + for _, v := range pkt.Views() { builder.Add(v) } - return rawfile.NonBlockingWriteIovec(fd, builder.Build()) } @@ -463,7 +460,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { vnetHdr := virtioNetHdr{} if pkt.GSOOptions != nil { - vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) + vnetHdr.hdrLen = uint16(pkt.HeaderSize()) if pkt.GSOOptions.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen @@ -486,8 +483,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc var builder iovec.Builder builder.Add(vnetHdrBuf) - builder.Add(pkt.Header.View()) - for _, v := range pkt.Data.Views() { + for _, v := range pkt.Views() { builder.Add(v) } iovecs := builder.Build() diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 7b995b85a..709f829c8 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -26,6 +26,7 @@ import ( "time" "unsafe" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -43,9 +44,36 @@ const ( ) type packetInfo struct { - raddr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - contents *stack.PacketBuffer + Raddr tcpip.LinkAddress + Proto tcpip.NetworkProtocolNumber + Contents *stack.PacketBuffer +} + +type packetContents struct { + LinkHeader buffer.View + NetworkHeader buffer.View + TransportHeader buffer.View + Data buffer.View +} + +func checkPacketInfoEqual(t *testing.T, got, want packetInfo) { + t.Helper() + if diff := cmp.Diff( + want, got, + cmp.Transformer("ExtractPacketBuffer", func(pk *stack.PacketBuffer) *packetContents { + if pk == nil { + return nil + } + return &packetContents{ + LinkHeader: pk.LinkHeader().View(), + NetworkHeader: pk.NetworkHeader().View(), + TransportHeader: pk.TransportHeader().View(), + Data: pk.Data.ToView(), + } + }), + ); diff != "" { + t.Errorf("unexpected packetInfo (-want +got):\n%s", diff) + } } type context struct { @@ -159,19 +187,28 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u RemoteLinkAddress: raddr, } - // Build header. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100) - b := hdr.Prepend(100) - for i := range b { - b[i] = uint8(rand.Intn(256)) + // Build payload. + payload := buffer.NewView(plen) + if _, err := rand.Read(payload); err != nil { + t.Fatalf("rand.Read(payload): %s", err) } - // Build payload and write. - payload := make(buffer.View, plen) - for i := range payload { - payload[i] = uint8(rand.Intn(256)) + // Build packet buffer. + const netHdrLen = 100 + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()) + netHdrLen, + Data: payload.ToVectorisedView(), + }) + pkt.Hash = hash + + // Build header. + b := pkt.NetworkHeader().Push(netHdrLen) + if _, err := rand.Read(b); err != nil { + t.Fatalf("rand.Read(b): %s", err) } - want := append(hdr.View(), payload...) + + // Write. + want := append(append(buffer.View(nil), b...), payload...) var gso *stack.GSO if gsoMaxSize != 0 { gso = &stack.GSO{ @@ -183,11 +220,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, proto, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - Hash: hash, - }); err != nil { + if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -296,13 +329,14 @@ func TestPreserveSrcAddress(t *testing.T) { LocalLinkAddress: baddr, } - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, proto, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.VectorisedView{}, - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + // WritePacket panics given a prependable with anything less than + // the minimum size of the ethernet header. + // TODO(b/153685824): Figure out if this should use c.ep.MaxHeaderLength(). + ReserveHeaderBytes: header.EthernetMinimumSize, + Data: buffer.VectorisedView{}, + }) + if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -331,24 +365,25 @@ func TestDeliverPacket(t *testing.T) { defer c.cleanup() // Build packet. - b := make([]byte, plen) - all := b - for i := range b { - b[i] = uint8(rand.Intn(256)) + all := make([]byte, plen) + if _, err := rand.Read(all); err != nil { + t.Fatalf("rand.Read(all): %s", err) } - - var hdr header.Ethernet - if !eth { - // So that it looks like an IPv4 packet. - b[0] = 0x40 - } else { - hdr = make(header.Ethernet, header.EthernetMinimumSize) + // Make it look like an IPv4 packet. + all[0] = 0x40 + + wantPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.EthernetMinimumSize, + Data: buffer.NewViewFromBytes(all).ToVectorisedView(), + }) + if eth { + hdr := header.Ethernet(wantPkt.LinkHeader().Push(header.EthernetMinimumSize)) hdr.Encode(&header.EthernetFields{ SrcAddr: raddr, DstAddr: laddr, Type: proto, }) - all = append(hdr, b...) + all = append(hdr, all...) } // Write packet via the file descriptor. @@ -360,24 +395,15 @@ func TestDeliverPacket(t *testing.T) { select { case pi := <-c.ch: want := packetInfo{ - raddr: raddr, - proto: proto, - contents: &stack.PacketBuffer{ - Data: buffer.View(b).ToVectorisedView(), - LinkHeader: buffer.View(hdr), - }, + Raddr: raddr, + Proto: proto, + Contents: wantPkt, } if !eth { - want.proto = header.IPv4ProtocolNumber - want.raddr = "" - } - // want.contents.Data will be a single - // view, so make pi do the same for the - // DeepEqual check. - pi.contents.Data = pi.contents.Data.ToView().ToVectorisedView() - if !reflect.DeepEqual(want, pi) { - t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) + want.Proto = header.IPv4ProtocolNumber + want.Raddr = "" } + checkPacketInfoEqual(t, pi, want) case <-time.After(10 * time.Second): t.Fatalf("Timed out waiting for packet") } @@ -572,8 +598,8 @@ func TestDispatchPacketFormat(t *testing.T) { t.Fatalf("len(sink.pkts) = %d, want %d", got, want) } pkt := sink.pkts[0] - if got, want := len(pkt.LinkHeader), header.EthernetMinimumSize; got != want { - t.Errorf("len(pkt.LinkHeader) = %d, want %d", got, want) + if got, want := pkt.LinkHeader().View().Size(), header.EthernetMinimumSize; got != want { + t.Errorf("pkt.LinkHeader().View().Size() = %d, want %d", got, want) } if got, want := pkt.Data.Size(), 4; got != want { t.Errorf("pkt.Data.Size() = %d, want %d", got, want) diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index 2dfd29aa9..c475dda20 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -18,6 +18,7 @@ package fdbased import ( "encoding/binary" + "fmt" "syscall" "golang.org/x/sys/unix" @@ -170,10 +171,9 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress - eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(pkt) + eth := header.Ethernet(pkt) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -190,10 +190,14 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } } - pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(remote, local, p, &stack.PacketBuffer{ - Data: buffer.View(pkt).ToVectorisedView(), - LinkHeader: buffer.View(eth), + pbuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.View(pkt).ToVectorisedView(), }) + if d.e.hdrSize > 0 { + if _, ok := pbuf.LinkHeader().Consume(d.e.hdrSize); !ok { + panic(fmt.Sprintf("LinkHeader().Consume(%d) must succeed", d.e.hdrSize)) + } + } + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pbuf) return true, nil } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index d8f2504b3..8c3ca86d6 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -103,7 +103,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { d.allocateViews(BufConfig) n, err := rawfile.BlockingReadv(d.fd, d.iovecs) - if err != nil { + if n == 0 || err != nil { return false, err } if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { @@ -111,17 +111,22 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { // isn't used and it isn't in a view. n -= virtioNetHdrSize } - if n <= d.e.hdrSize { - return false, nil - } + + used := d.capViews(n, BufConfig) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), + }) var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress - eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(d.views[0][:header.EthernetMinimumSize]) + hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize) + if !ok { + return false, nil + } + eth := header.Ethernet(hdr) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -138,13 +143,6 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } } - used := d.capViews(n, BufConfig) - pkt := &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), - LinkHeader: buffer.View(eth), - } - pkt.Data.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) // Prepare e.views for another packet: release used views. @@ -268,17 +266,22 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { n -= virtioNetHdrSize } - if n <= d.e.hdrSize { - return false, nil - } + + used := d.capViews(k, int(n), BufConfig) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), + }) var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress - eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(d.views[k][0][:header.EthernetMinimumSize]) + hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize) + if !ok { + return false, nil + } + eth := header.Ethernet(hdr) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -295,12 +298,6 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } } - used := d.capViews(k, int(n), BufConfig) - pkt := &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), - LinkHeader: buffer.View(eth), - } - pkt.Data.TrimFront(d.e.hdrSize) d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) // Prepare e.views for another packet: release used views. diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 781cdd317..38aa694e4 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -77,16 +77,16 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) + // Construct data as the unparsed portion for the loopback packet. + data := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) // Because we're immediately turning around and writing the packet back // to the rx path, we intentionally don't preserve the remote and local // link addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: data, }) + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, newPkt) return nil } @@ -98,18 +98,17 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) // There should be an ethernet header at the beginning of vv. - hdr, ok := vv.PullUp(header.EthernetMinimumSize) + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) if !ok { // Reject the packet if it's shorter than an ethernet header. return tcpip.ErrBadAddress } linkHeader := header.Ethernet(hdr) - vv.TrimFront(len(linkHeader)) - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), &stack.PacketBuffer{ - Data: vv, - LinkHeader: buffer.View(linkHeader), - }) + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), pkt) return nil } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 0744f66d6..3e4afcdad 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -46,14 +46,14 @@ func TestInjectableEndpointRawDispatch(t *testing.T) { func TestInjectableEndpointDispatch(t *testing.T) { endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - hdr := buffer.NewPrependable(1) - hdr.Prepend(1)[0] = 0xFA + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: 1, + Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), + }) + pkt.TransportHeader().Push(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), - }) + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) @@ -67,13 +67,14 @@ func TestInjectableEndpointDispatch(t *testing.T) { func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - hdr := buffer.NewPrependable(1) - hdr.Prepend(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.NewView(0).ToVectorisedView(), + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: 1, + Data: buffer.NewView(0).ToVectorisedView(), }) + pkt.TransportHeader().Push(1)[0] = 0xFA + packetRoute := stack.Route{RemoteAddress: dstIP} + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) if err != nil { diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go index 7d9249c1c..c1f9d308c 100644 --- a/pkg/tcpip/link/nested/nested_test.go +++ b/pkg/tcpip/link/nested/nested_test.go @@ -87,7 +87,7 @@ func TestNestedLinkEndpoint(t *testing.T) { t.Error("After attach, nestedEP.IsAttached() = false, want = true") } - nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{}) + nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if disp.count != 1 { t.Errorf("After first packet with dispatcher attached, got disp.count = %d, want = 1", disp.count) } @@ -101,7 +101,7 @@ func TestNestedLinkEndpoint(t *testing.T) { } disp.count = 0 - nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{}) + nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if disp.count != 0 { t.Errorf("After second packet with dispatcher detached, got disp.count = %d, want = 0", disp.count) } diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 507c76b76..7fb8a6c49 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -186,8 +186,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { // AddHeader implements stack.LinkEndpoint.AddHeader. func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { // Add ethernet header if needed. - eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) - pkt.LinkHeader = buffer.View(eth) + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) ethHdr := &header.EthernetFields{ DstAddr: remote, Type: protocol, @@ -207,10 +206,10 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) - v := pkt.Data.ToView() + views := pkt.Views() // Transmit the packet. e.mu.Lock() - ok := e.tx.transmit(pkt.Header.View(), v) + ok := e.tx.transmit(views...) e.mu.Unlock() if !ok { @@ -227,10 +226,10 @@ func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketB // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - v := vv.ToView() + views := vv.Views() // Transmit the packet. e.mu.Lock() - ok := e.tx.transmit(v, buffer.View{}) + ok := e.tx.transmit(views...) e.mu.Unlock() if !ok { @@ -276,16 +275,18 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { rxb[i].Size = e.bufferSize } - if n < header.EthernetMinimumSize { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.View(b).ToVectorisedView(), + }) + + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) + if !ok { continue } + eth := header.Ethernet(hdr) // Send packet up the stack. - eth := header.Ethernet(b[:header.EthernetMinimumSize]) - d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), &stack.PacketBuffer{ - Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), - LinkHeader: buffer.View(eth), - }) + d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), pkt) } // Clean state. diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 8f3cd9449..22d5c97f1 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -266,21 +266,23 @@ func TestSimpleSend(t *testing.T) { for iters := 1000; iters > 0; iters-- { func() { + hdrLen, dataLen := rand.Intn(10000), rand.Intn(10000) + // Prepare and send packet. - n := rand.Intn(10000) - hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength())) - hdrBuf := hdr.Prepend(n) + hdrBuf := buffer.NewView(hdrLen) randomFill(hdrBuf) - n = rand.Intn(10000) - buf := buffer.NewView(n) - randomFill(buf) + data := buffer.NewView(dataLen) + randomFill(data) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: hdrLen + int(c.ep.MaxHeaderLength()), + Data: data.ToVectorisedView(), + }) + copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -317,7 +319,7 @@ func TestSimpleSend(t *testing.T) { // Compare contents skipping the ethernet header added by the // endpoint. - merged := append(hdrBuf, buf...) + merged := append(hdrBuf, data...) if uint32(len(contents)) < pi.Size { t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size) } @@ -344,14 +346,14 @@ func TestPreserveSrcAddressInSend(t *testing.T) { LocalLinkAddress: newLocalLinkAddress, } - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - hdr := buffer.NewPrependable(header.EthernetMinimumSize) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + // WritePacket panics given a prependable with anything less than + // the minimum size of the ethernet header. + ReserveHeaderBytes: header.EthernetMinimumSize, + }) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ - Header: hdr, - }); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -403,12 +405,12 @@ func TestFillTxQueue(t *testing.T) { // until the tx queue if full. ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -422,11 +424,11 @@ func TestFillTxQueue(t *testing.T) { } // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != want { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -450,11 +452,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } @@ -473,11 +475,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // until the tx queue if full. ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -491,11 +493,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { } // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != want { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -517,11 +519,11 @@ func TestFillTxMemory(t *testing.T) { // we fill the memory. ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -536,11 +538,11 @@ func TestFillTxMemory(t *testing.T) { } // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), }) + err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt) if want := tcpip.ErrWouldBlock; err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } @@ -564,11 +566,11 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Each packet is uses up one buffer, so write as many as possible // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -579,23 +581,22 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write a two-buffer packet. It must fail. { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: uu, - }); err != want { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buffer.NewView(bufferSize).ToVectorisedView(), + }) + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } // Attempt to write the one-buffer packet again. It must succeed. { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go index 6b8d7859d..44f421c2d 100644 --- a/pkg/tcpip/link/sharedmem/tx.go +++ b/pkg/tcpip/link/sharedmem/tx.go @@ -18,6 +18,7 @@ import ( "math" "syscall" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" ) @@ -76,9 +77,9 @@ func (t *tx) cleanup() { syscall.Munmap(t.data) } -// transmit sends a packet made up of up to two buffers. Returns a boolean that -// specifies whether the packet was successfully transmitted. -func (t *tx) transmit(a, b []byte) bool { +// transmit sends a packet made of bufs. Returns a boolean that specifies +// whether the packet was successfully transmitted. +func (t *tx) transmit(bufs ...buffer.View) bool { // Pull completions from the tx queue and add their buffers back to the // pool so that we can reuse them. for { @@ -93,7 +94,10 @@ func (t *tx) transmit(a, b []byte) bool { } bSize := t.bufs.entrySize - total := uint32(len(a) + len(b)) + total := uint32(0) + for _, data := range bufs { + total += uint32(len(data)) + } bufCount := (total + bSize - 1) / bSize // Allocate enough buffers to hold all the data. @@ -115,7 +119,7 @@ func (t *tx) transmit(a, b []byte) bool { // Copy data into allocated buffers. nBuf := buf var dBuf []byte - for _, data := range [][]byte{a, b} { + for _, data := range bufs { for len(data) > 0 { if len(dBuf) == 0 { dBuf = t.data[nBuf.Offset:][:nBuf.Size] diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 509076643..4fb127978 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -134,7 +134,7 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw logPacket(prefix, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { - totalLength := pkt.Header.UsedLength() + pkt.Data.Size() + totalLength := pkt.Size() length := totalLength if max := int(e.maxPCAPLen); length > max { length = max @@ -155,12 +155,11 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw length -= n } } - write(pkt.Header.View()) - for _, view := range pkt.Data.Views() { + for _, v := range pkt.Views() { if length == 0 { break } - write(view) + write(v) } } } @@ -185,9 +184,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - e.dumpPacket("send", nil, 0, &stack.PacketBuffer{ + e.dumpPacket("send", nil, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - }) + })) return e.Endpoint.WriteRawPacket(vv) } @@ -201,12 +200,8 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P var fragmentOffset uint16 var moreFragments bool - // Create a clone of pkt, including any headers if present. Avoid allocating - // backing memory for the clone. - views := [8]buffer.View{} - vv := buffer.NewVectorisedView(0, views[:0]) - vv.AppendView(pkt.Header.View()) - vv.Append(pkt.Data) + // Examine the packet using a new VV. Backing storage must not be written. + vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) switch protocol { case header.IPv4ProtocolNumber: diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 22b0a12bd..3b1510a33 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -215,12 +215,11 @@ func (d *Device) Write(data []byte) (int64, error) { remote = tcpip.LinkAddress(zeroMAC[:]) } - pkt := &stack.PacketBuffer{ - Data: buffer.View(data).ToVectorisedView(), - } - if ethHdr != nil { - pkt.LinkHeader = buffer.View(ethHdr) - } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: len(ethHdr), + Data: buffer.View(data).ToVectorisedView(), + }) + copy(pkt.LinkHeader().Push(len(ethHdr)), ethHdr) endpoint.InjectLinkAddr(protocol, remote, pkt) return dataLen, nil } @@ -265,21 +264,22 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { // If the packet does not already have link layer header, and the route // does not exist, we can't compute it. This is possibly a raw packet, tun // device doesn't support this at the moment. - if info.Pkt.LinkHeader == nil && info.Route.RemoteLinkAddress == "" { + if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress == "" { return nil, false } // Ethernet header (TAP only). if d.hasFlags(linux.IFF_TAP) { // Add ethernet header if not provided. - if info.Pkt.LinkHeader == nil { + if info.Pkt.LinkHeader().View().IsEmpty() { d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt) } - vv.AppendView(info.Pkt.LinkHeader) + vv.AppendView(info.Pkt.LinkHeader().View()) } // Append upper headers. - vv.AppendView(buffer.View(info.Pkt.Header.View()[len(info.Pkt.LinkHeader):])) + vv.AppendView(info.Pkt.NetworkHeader().View()) + vv.AppendView(info.Pkt.TransportHeader().View()) // Append data payload. vv.Append(info.Pkt.Data) @@ -361,8 +361,7 @@ func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip. if !e.isTap { return } - eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) - pkt.LinkHeader = buffer.View(eth) + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) hdr := &header.EthernetFields{ SrcAddr: local, DstAddr: remote, diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index c448a888f..94827fc56 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -104,21 +104,21 @@ func TestWaitWrite(t *testing.T) { wep := New(ep) // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } @@ -135,21 +135,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 31a242482..1ad788a17 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -99,7 +99,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu } func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - h := header.ARP(pkt.NetworkHeader) + h := header.ARP(pkt.NetworkHeader().View()) if !h.IsValid() { return } @@ -110,17 +110,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } - hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) - packet := header.ARP(hdr.Prepend(header.ARPSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize, + }) + packet := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) packet.SetIPv4OverEthernet() packet.SetOp(header.ARPReply) copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - }) + _ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) fallthrough // also fill the cache from requests case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) @@ -168,17 +168,17 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd r.RemoteLinkAddress = header.EthernetBroadcastAddress } - hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize) - h := header.ARP(hdr.Prepend(header.ARPSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.ARPSize, + }) + h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) h.SetIPv4OverEthernet() h.SetOp(header.ARPRequest) copy(h.HardwareAddressSender(), linkEP.LinkAddress()) copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - }) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. @@ -210,12 +210,10 @@ func (*protocol) Wait() {} // Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { - hdr, ok := pkt.Data.PullUp(header.ARPSize) + _, ok = pkt.NetworkHeader().Consume(header.ARPSize) if !ok { return 0, false, false } - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(header.ARPSize) return 0, false, true } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index a35a64a0f..c2c3e6891 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -106,9 +106,9 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: v.ToVectorisedView(), - }) + })) } for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { @@ -118,9 +118,9 @@ func TestDirectRequest(t *testing.T) { if pi.Proto != arp.ProtocolNumber { t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) } - rep := header.ARP(pi.Pkt.Header.View()) + rep := header.ARP(pi.Pkt.NetworkHeader().View()) if !rep.IsValid() { - t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength()) + t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) } if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr1; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 615bae648..e6768258a 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -156,13 +156,13 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne var dstAddr tcpip.Address if t.v4 { - h := header.IPv4(pkt.Header.View()) + h := header.IPv4(pkt.NetworkHeader().View()) prot = tcpip.TransportProtocolNumber(h.Protocol()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } else { - h := header.IPv6(pkt.Header.View()) + h := header.IPv6(pkt.NetworkHeader().View()) prot = tcpip.TransportProtocolNumber(h.NextHeader()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() @@ -243,8 +243,11 @@ func TestIPv4Send(t *testing.T) { payload[i] = uint8(i) } - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + // Setup the packet buffer. + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(ep.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + }) // Issue the write. o.protocol = 123 @@ -260,10 +263,7 @@ func TestIPv4Send(t *testing.T) { Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }); err != nil { + }, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -303,9 +303,13 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - pkt := stack.PacketBuffer{Data: view.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -455,17 +459,25 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - pkt := stack.PacketBuffer{Data: frag1.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: frag1.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) } // Send second segment. - pkt = stack.PacketBuffer{Data: frag2.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: frag2.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -485,8 +497,11 @@ func TestIPv6Send(t *testing.T) { payload[i] = uint8(i) } - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + // Setup the packet buffer. + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(ep.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + }) // Issue the write. o.protocol = 123 @@ -502,10 +517,7 @@ func TestIPv6Send(t *testing.T) { Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }); err != nil { + }, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -545,9 +557,13 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - pkt := stack.PacketBuffer{Data: view.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -673,11 +689,9 @@ func TestIPv6ReceiveControl(t *testing.T) { // becomes Data. func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer { v := view[:len(view)-trunc] - if len(v) < netHdrLen { - return &stack.PacketBuffer{Data: v.ToVectorisedView()} - } - return &stack.PacketBuffer{ - NetworkHeader: v[:netHdrLen], - Data: v[netHdrLen:].ToVectorisedView(), - } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v.ToVectorisedView(), + }) + _, _ = pkt.NetworkHeader().Consume(netHdrLen) + return pkt } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 94803a359..067d770f3 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -89,12 +89,14 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { return } + // Make a copy of data before pkt gets sent to raw socket. + // DeliverTransportPacket will take ownership of pkt. + replyData := pkt.Data.Clone(nil) + replyData.TrimFront(header.ICMPv4MinimumSize) + // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, &stack.PacketBuffer{ - Data: pkt.Data.Clone(nil), - NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), - }) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) remoteLinkAddr := r.RemoteLinkAddress @@ -116,24 +118,26 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // Use the remote link address from the incoming packet. r.ResolveWith(remoteLinkAddr) - vv := pkt.Data.Clone(nil) - vv.TrimFront(header.ICMPv4MinimumSize) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - copy(pkt, h) - pkt.SetType(header.ICMPv4EchoReply) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) + // Prepare a reply packet. + icmpHdr := make(header.ICMPv4, header.ICMPv4MinimumSize) + copy(icmpHdr, h) + icmpHdr.SetType(header.ICMPv4EchoReply) + icmpHdr.SetChecksum(0) + icmpHdr.SetChecksum(^header.Checksum(icmpHdr, header.ChecksumVV(replyData, 0))) + dataVV := buffer.View(icmpHdr).ToVectorisedView() + dataVV.Append(replyData) + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: dataVV, + }) + + // Send out the reply packet. sent := stats.ICMP.V4PacketsSent if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: vv, - TransportHeader: buffer.View(pkt), - }); err != nil { + }, replyPkt); err != nil { sent.Dropped.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 9ff27a363..3cd48ceb3 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -21,7 +21,6 @@ package ipv4 import ( - "fmt" "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" @@ -127,14 +126,12 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { } // writePacketFragments calls e.linkEP.WritePacket with each packet fragment to -// write. It assumes that the IP header is entirely in pkt.Header but does not -// assume that only the IP header is in pkt.Header. It assumes that the input -// packet's stated length matches the length of the header+payload. mtu -// includes the IP header and options. This does not support the DontFragment -// IP flag. +// write. It assumes that the IP header is already present in pkt.NetworkHeader. +// pkt.TransportHeader may be set. mtu includes the IP header and options. This +// does not support the DontFragment IP flag. func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. - ip := header.IPv4(pkt.Header.View()) + ip := header.IPv4(pkt.NetworkHeader().View()) flags := ip.Flags() // Update mtu to take into account the header, which will exist in all @@ -148,88 +145,84 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, outerMTU := innerMTU + int(ip.HeaderLength()) offset := ip.FragmentOffset() - originalAvailableLength := pkt.Header.AvailableLength() + + // Keep the length reserved for link-layer, we need to create fragments with + // the same reserved length. + reservedForLink := pkt.AvailableHeaderBytes() + + // Destroy the packet, pull all payloads out for fragmentation. + transHeader, data := pkt.TransportHeader().View(), pkt.Data + + // Where possible, the first fragment that is sent has the same + // number of bytes reserved for header as the input packet. The link-layer + // endpoint may depend on this for looking at, eg, L4 headers. + transFitsFirst := len(transHeader) <= innerMTU + for i := 0; i < n; i++ { - // Where possible, the first fragment that is sent has the same - // pkt.Header.UsedLength() as the input packet. The link-layer - // endpoint may depend on this for looking at, eg, L4 headers. - h := ip - if i > 0 { - pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) - h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength()))) - copy(h, ip[:ip.HeaderLength()]) + reserve := reservedForLink + int(ip.HeaderLength()) + if i == 0 && transFitsFirst { + // Reserve for transport header if it's going to be put in the first + // fragment. + reserve += len(transHeader) + } + fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: reserve, + }) + fragPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + + // Copy data for the fragment. + avail := innerMTU + + if n := len(transHeader); n > 0 { + if n > avail { + n = avail + } + if i == 0 && transFitsFirst { + copy(fragPkt.TransportHeader().Push(n), transHeader) + } else { + fragPkt.Data.AppendView(transHeader[:n:n]) + } + transHeader = transHeader[n:] + avail -= n } + + if avail > 0 { + n := data.Size() + if n > avail { + n = avail + } + data.ReadToVV(&fragPkt.Data, n) + avail -= n + } + + copied := uint16(innerMTU - avail) + + // Set lengths in header and calculate checksum. + h := header.IPv4(fragPkt.NetworkHeader().Push(len(ip))) + copy(h, ip) if i != n-1 { h.SetTotalLength(uint16(outerMTU)) h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) } else { - h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size())) + h.SetTotalLength(uint16(h.HeaderLength()) + copied) h.SetFlagsFragmentOffset(flags, offset) } h.SetChecksum(0) h.SetChecksum(^h.CalculateChecksum()) - offset += uint16(innerMTU) - if i > 0 { - newPayload := pkt.Data.Clone(nil) - newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: pkt.Header, - Data: newPayload, - NetworkHeader: buffer.View(h), - NetworkProtocolNumber: header.IPv4ProtocolNumber, - }); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - pkt.Data.TrimFront(newPayload.Size()) - continue - } - // Special handling for the first fragment because it comes - // from the header. - if outerMTU >= pkt.Header.UsedLength() { - // This fragment can fit all of pkt.Header and possibly - // some of pkt.Data, too. - newPayload := pkt.Data.Clone(nil) - newPayloadLength := outerMTU - pkt.Header.UsedLength() - newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: pkt.Header, - Data: newPayload, - NetworkHeader: buffer.View(h), - NetworkProtocolNumber: header.IPv4ProtocolNumber, - }); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - pkt.Data.TrimFront(newPayloadLength) - } else { - // The fragment is too small to fit all of pkt.Header. - startOfHdr := pkt.Header - startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) - emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: startOfHdr, - Data: emptyVV, - NetworkHeader: buffer.View(h), - NetworkProtocolNumber: header.IPv4ProtocolNumber, - }); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - // Add the unused bytes of pkt.Header into the pkt.Data - // that remains to be sent. - restOfHdr := pkt.Header.View()[outerMTU:] - tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)}) - tmp.Append(pkt.Data) - pkt.Data = tmp + offset += copied + + // Send out the fragment. + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { + return err } + r.Stats().IP.PacketsSent.Increment() } return nil } -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - length := uint16(hdr.UsedLength() + payloadSize) +func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { + ip := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) + length := uint16(pkt.Size()) // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic // datagrams. Since the DF bit is never being set here, all datagrams // are non-atomic and need an ID. @@ -245,14 +238,12 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) - return ip + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber } // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) - pkt.NetworkHeader = buffer.View(ip) - pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + e.addIPHeader(r, pkt, params) // iptables filtering. All packets that reach here are locally // generated. @@ -269,7 +260,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // only NATted packets, but removing this check short circuits broadcasts // before they are sent out to other hosts. if pkt.NatDone { - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) if err == nil { route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) @@ -286,7 +277,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw if r.Loop&stack.PacketOut == 0 { return nil } - if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) } if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { @@ -306,9 +297,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pkt := pkts.Front(); pkt != nil; { - ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) - pkt.NetworkHeader = buffer.View(ip) - pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + e.addIPHeader(r, pkt, params) pkt = pkt.Next() } @@ -333,7 +322,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe continue } if _, ok := natPkts[pkt]; ok { - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() @@ -402,17 +391,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu r.Stats().IP.PacketsSent.Increment() - ip = ip[:ip.HeaderLength()] - pkt.Header = buffer.NewPrependableFromView(buffer.View(ip)) - pkt.Data.TrimFront(int(ip.HeaderLength())) return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - h := header.IPv4(pkt.NetworkHeader) - if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) { + h := header.IPv4(pkt.NetworkHeader().View()) + if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } @@ -426,7 +412,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } if h.More() || h.FragmentOffset() != 0 { - if pkt.Data.Size()+len(pkt.TransportHeader) == 0 { + if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. r.Stats().IP.MalformedPacketsReceived.Increment() @@ -470,7 +456,6 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { - pkt.NetworkHeader.CapLength(int(h.HeaderLength())) e.handleICMP(r, pkt) return } @@ -560,14 +545,19 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu } ipHdr := header.IPv4(hdr) - // If there are options, pull those into hdr as well. - if headerLen := int(ipHdr.HeaderLength()); headerLen > header.IPv4MinimumSize && headerLen <= pkt.Data.Size() { - hdr, ok = pkt.Data.PullUp(headerLen) - if !ok { - panic(fmt.Sprintf("There are only %d bytes in pkt.Data, but there should be at least %d", pkt.Data.Size(), headerLen)) - } - ipHdr = header.IPv4(hdr) + // Header may have options, determine the true header length. + headerLen := int(ipHdr.HeaderLength()) + if headerLen < header.IPv4MinimumSize { + // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in + // order for the packet to be valid. Figure out if we want to reject this + // case. + headerLen = header.IPv4MinimumSize + } + hdr, ok = pkt.NetworkHeader().Consume(headerLen) + if !ok { + return 0, false, false } + ipHdr = header.IPv4(hdr) // If this is a fragment, don't bother parsing the transport header. parseTransportHeader := true @@ -576,8 +566,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu } pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(len(hdr)) pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) return ipHdr.TransportProtocol(), parseTransportHeader, true } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 63e2c36c2..afd3ac06d 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -17,6 +17,7 @@ package ipv4_test import ( "bytes" "encoding/hex" + "fmt" "math/rand" "testing" @@ -91,15 +92,11 @@ func TestExcludeBroadcast(t *testing.T) { }) } -// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much +// makeRandPkt generates a randomize packet. hdrLength indicates how much // data should already be in the header before WritePacket. extraLength // indicates how much extra space should be in the header. The payload is made // from many Views of the sizes listed in viewSizes. -func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) { - hdr := buffer.NewPrependable(hdrLength + extraLength) - hdr.Prepend(hdrLength) - rand.Read(hdr.View()) - +func makeRandPkt(hdrLength int, extraLength int, viewSizes []int) *stack.PacketBuffer { var views []buffer.View totalLength := 0 for _, s := range viewSizes { @@ -108,8 +105,16 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. views = append(views, newView) totalLength += s } - payload := buffer.NewVectorisedView(totalLength, views) - return hdr, payload + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: hdrLength + extraLength, + Data: buffer.NewVectorisedView(totalLength, views), + }) + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + if _, err := rand.Read(pkt.TransportHeader().Push(hdrLength)); err != nil { + panic(fmt.Sprintf("rand.Read: %s", err)) + } + return pkt } // comparePayloads compared the contents of all the packets against the contents @@ -117,9 +122,9 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. - source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) - source = append(source, sourcePacketInfo.Header.View()...) - source = append(source, sourcePacketInfo.Data.ToView()...) + source := header.IPv4(packets[0].NetworkHeader().View()[:header.IPv4MinimumSize]) + vv := buffer.NewVectorisedView(sourcePacketInfo.Size(), sourcePacketInfo.Views()) + source = append(source, vv.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make // an expected header. @@ -132,8 +137,7 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI var reassembledPayload []byte for i, packet := range packets { // Confirm that the packet is valid. - allBytes := packet.Header.View().ToVectorisedView() - allBytes.Append(packet.Data) + allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views()) ip := header.IPv4(allBytes.ToView()) if !ip.IsValid(len(ip)) { t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) @@ -144,10 +148,17 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI if got, want := len(ip), int(mtu); got > want { t.Errorf("fragment is too large, got %d want %d", got, want) } - if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want { - t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) + if i == 0 { + got := packet.NetworkHeader().View().Size() + packet.TransportHeader().View().Size() + // sourcePacketInfo does not have NetworkHeader added, simulate one. + want := header.IPv4MinimumSize + sourcePacketInfo.TransportHeader().View().Size() + // Check that it kept the transport header in packet.TransportHeader if + // it fits in the first fragment. + if want < int(mtu) && got != want { + t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) + } } - if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want { + if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want { t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) } if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want { @@ -284,22 +295,14 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := &stack.PacketBuffer{ - Header: hdr, - // Save the source payload because WritePacket will modify it. - Data: payload.Clone(nil), - NetworkProtocolNumber: header.IPv4ProtocolNumber, - } + pkt := makeRandPkt(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) + source := pkt.Clone() c := buildContext(t, nil, ft.mtu) err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload, - }) + }, pkt) if err != nil { t.Errorf("err got %v, want %v", err, nil) } @@ -344,16 +347,13 @@ func TestFragmentationErrors(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) + pkt := makeRandPkt(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload, - }) + }, pkt) for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) @@ -472,9 +472,9 @@ func TestInvalidFragments(t *testing.T) { s.CreateNIC(nicID, sniffer.New(ep)) for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, &stack.PacketBuffer{ + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), - }) + })) } if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { @@ -859,9 +859,9 @@ func TestReceiveFragments(t *testing.T) { vv := hdr.View().ToVectorisedView() vv.AppendView(frag.payload) - e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - }) + })) } if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index ded91d83a..39ae19295 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -83,7 +83,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } h := header.ICMPv6(v) - iph := header.IPv6(pkt.NetworkHeader) + iph := header.IPv6(pkt.NetworkHeader().View()) // Validate ICMPv6 checksum before processing the packet. // @@ -276,8 +276,10 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme optsSerializer := header.NDPOptionsSerializer{ header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress), } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) - packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()), + }) + packet := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize)) packet.SetType(header.ICMPv6NeighborAdvert) na := header.NDPNeighborAdvert(packet.NDPPayload()) na.SetSolicitedFlag(solicited) @@ -293,9 +295,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - }); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil { sent.Dropped.Increment() return } @@ -384,7 +384,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) + icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize) if !ok { received.Invalid.Increment() return @@ -409,16 +409,15 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // Use the link address from the source of the original packet. r.ResolveWith(remoteLinkAddr) - pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) - packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, + Data: pkt.Data, + }) + packet := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: pkt.Data, - }); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, replyPkt); err != nil { sent.Dropped.Increment() return } @@ -539,17 +538,19 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr) } - hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - copy(pkt[icmpV6OptOffset-len(addr):], addr) - pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr - pkt[icmpV6LengthOffset] = 1 - copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress()) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - - length := uint16(hdr.UsedLength()) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize, + }) + icmpHdr := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize)) + icmpHdr.SetType(header.ICMPv6NeighborSolicit) + copy(icmpHdr[icmpV6OptOffset-len(addr):], addr) + icmpHdr[icmpV6OptOffset] = ndpOptSrcLinkAddr + icmpHdr[icmpV6LengthOffset] = 1 + copy(icmpHdr[icmpV6LengthOffset+1:], linkEP.LinkAddress()) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + length := uint16(pkt.Size()) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(header.ICMPv6ProtocolNumber), @@ -559,9 +560,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - }) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver. diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index f86aaed1d..2a2f7de01 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -183,7 +183,11 @@ func TestICMPCounts(t *testing.T) { } handleIPv6Payload := func(icmp header.ICMPv6) { - ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize, + Data: buffer.View(icmp).ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), @@ -191,10 +195,7 @@ func TestICMPCounts(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, &stack.PacketBuffer{ - NetworkHeader: buffer.View(ip), - Data: buffer.View(icmp).ToVectorisedView(), - }) + ep.HandlePacket(&r, pkt) } for _, typ := range types { @@ -323,12 +324,10 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. pi, _ := args.src.ReadContext(context.Background()) { - views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} - size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() - vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), &stack.PacketBuffer{ - Data: vv, + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pi.Pkt.Size(), pi.Pkt.Views()), }) + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), pkt) } if pi.Proto != ProtocolNumber { @@ -340,7 +339,9 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) } - ipv6 := header.IPv6(pi.Pkt.Header.View()) + // Pull the full payload since network header. Needed for header.IPv6 to + // extract its payload. + ipv6 := header.IPv6(stack.PayloadSince(pi.Pkt.NetworkHeader())) transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) if transProto != header.ICMPv6ProtocolNumber { t.Errorf("unexpected transport protocol number %d", transProto) @@ -558,9 +559,10 @@ func TestICMPChecksumValidationSimple(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived @@ -719,12 +721,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { icmpSize := size + payloadSize hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(typ) - payloadFn(pkt.Payload()) + icmpHdr := header.ICMPv6(hdr.Prepend(icmpSize)) + icmpHdr.SetType(typ) + payloadFn(icmpHdr.Payload()) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, buffer.VectorisedView{})) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -735,9 +737,10 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived @@ -895,14 +898,14 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) - pkt := header.ICMPv6(hdr.Prepend(size)) - pkt.SetType(typ) + icmpHdr := header.ICMPv6(hdr.Prepend(size)) + icmpHdr.SetType(typ) payload := buffer.NewView(payloadSize) payloadFn(payload) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView())) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, payload.ToVectorisedView())) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -913,9 +916,10 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index d7d7fc611..0ade655b2 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -99,9 +99,9 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 { - length := uint16(hdr.UsedLength() + payloadSize) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) +func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { + length := uint16(pkt.Size()) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(params.Protocol), @@ -110,26 +110,20 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - return ip + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber } // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) - pkt.NetworkHeader = buffer.View(ip) - pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber + e.addIPHeader(r, pkt, params) if r.Loop&stack.PacketLoop != 0 { - // The inbound path expects the network header to still be in - // the PacketBuffer's Data field. - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) + e.HandlePacket(&loopedR, stack.NewPacketBuffer(stack.PacketBufferOptions{ + // The inbound path expects an unparsed packet. + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + })) loopedR.Release() } @@ -151,9 +145,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pb := pkts.Front(); pb != nil; pb = pb.Next() { - ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params) - pb.NetworkHeader = buffer.View(ip) - pb.NetworkProtocolNumber = header.IPv6ProtocolNumber + e.addIPHeader(r, pb, params) } n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) @@ -171,8 +163,8 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuff // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - h := header.IPv6(pkt.NetworkHeader) - if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) { + h := header.IPv6(pkt.NetworkHeader().View()) + if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } @@ -181,8 +173,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // - Any IPv6 header bytes after the first 40 (i.e. extensions). // - The transport header, if present. // - Any other payload data. - vv := pkt.NetworkHeader[header.IPv6MinimumSize:].ToVectorisedView() - vv.AppendView(pkt.TransportHeader) + vv := pkt.NetworkHeader().View()[header.IPv6MinimumSize:].ToVectorisedView() + vv.AppendView(pkt.TransportHeader().View()) vv.Append(pkt.Data) it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) hasFragmentHeader := false @@ -410,7 +402,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // // For reassembled fragments, pkt.TransportHeader is unset, so this is a // no-op and pkt.Data begins with the transport header. - extHdr.Buf.TrimFront(len(pkt.TransportHeader)) + extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) pkt.Data = extHdr.Buf if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { @@ -581,17 +573,14 @@ traverseExtensions: } } - // Put the IPv6 header with extensions in pkt.NetworkHeader. - hdr, ok = pkt.Data.PullUp(header.IPv6MinimumSize + extensionsSize) + // Put the IPv6 header with extensions in pkt.NetworkHeader(). + hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize) if !ok { panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) } ipHdr = header.IPv6(hdr) - - pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(len(hdr)) pkt.Data.CapLength(int(ipHdr.PayloadLength())) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber return nextHdr, foundNext, true } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 3d65814de..081afb051 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -65,9 +65,9 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) stats := s.Stats().ICMP.V6PacketsReceived @@ -123,9 +123,9 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) stat := s.Stats().UDP.PacketsReceived @@ -637,9 +637,9 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { DstAddr: addr2, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) stats := s.Stats().UDP.PacketsReceived @@ -1469,9 +1469,9 @@ func TestReceiveIPv6Fragments(t *testing.T) { vv := hdr.View().ToVectorisedView() vv.Append(f.data) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - }) + })) } if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 64239ce9a..fe159b24f 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -136,9 +136,9 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) if linkAddr != test.expectedLinkAddr { @@ -380,9 +380,9 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, &stack.PacketBuffer{ + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) if test.nsInvalid { if got := invalid.Value(); got != 1 { @@ -410,7 +410,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) } - checker.IPv6(t, p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(test.naSrc), checker.DstAddr(test.naDst), checker.TTL(header.NDPHopLimit), @@ -497,9 +497,9 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) if linkAddr != test.expectedLinkAddr { @@ -560,7 +560,11 @@ func TestNDPValidation(t *testing.T) { nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) } - ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize + len(extensions))) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions), + Data: payload.ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions))) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(len(payload) + len(extensions)), NextHeader: nextHdr, @@ -571,10 +575,7 @@ func TestNDPValidation(t *testing.T) { if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) } - ep.HandlePacket(r, &stack.PacketBuffer{ - NetworkHeader: buffer.View(ip), - Data: payload.ToVectorisedView(), - }) + ep.HandlePacket(r, pkt) } var tllData [header.NDPLinkLayerAddressSize]byte @@ -885,9 +886,9 @@ func TestRouterAdvertValidation(t *testing.T) { t.Fatalf("got rxRA = %d, want = 0", got) } - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) if got := rxRA.Value(); got != 1 { t.Fatalf("got rxRA = %d, want = 1", got) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index bfc7a0c7c..900938dd1 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -57,6 +57,7 @@ go_library( "conntrack.go", "dhcpv6configurationfromndpra_string.go", "forwarder.go", + "headertype_string.go", "icmp_rate_limit.go", "iptables.go", "iptables_state.go", @@ -143,6 +144,7 @@ go_test( "neighbor_cache_test.go", "neighbor_entry_test.go", "nic_test.go", + "packet_buffer_test.go", ], library = ":stack", deps = [ diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 470c265aa..7dd344b4f 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -199,12 +199,12 @@ type bucket struct { func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { // TODO(gvisor.dev/issue/170): Need to support for other // protocols as well. - netHeader := header.IPv4(pkt.NetworkHeader) - if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber { return tupleID{}, tcpip.ErrUnknownProtocol } - tcpHeader := header.TCP(pkt.TransportHeader) - if tcpHeader == nil { + tcpHeader := header.TCP(pkt.TransportHeader().View()) + if len(tcpHeader) < header.TCPMinimumSize { return tupleID{}, tcpip.ErrUnknownProtocol } @@ -344,8 +344,8 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { return } - netHeader := header.IPv4(pkt.NetworkHeader) - tcpHeader := header.TCP(pkt.TransportHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) + tcpHeader := header.TCP(pkt.TransportHeader().View()) // For prerouting redirection, packets going in the original direction // have their destinations modified and replies have their sources @@ -377,8 +377,8 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d return } - netHeader := header.IPv4(pkt.NetworkHeader) - tcpHeader := header.TCP(pkt.TransportHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) + tcpHeader := header.TCP(pkt.TransportHeader().View()) // For output redirection, packets going in the original direction // have their destinations modified and replies have their sources @@ -396,8 +396,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - hdr := &pkt.Header - length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength()) + length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) @@ -423,7 +422,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou } // TODO(gvisor.dev/issue/170): Support other transport protocols. - if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { return false } @@ -433,8 +432,8 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou return true } - tcpHeader := header.TCP(pkt.TransportHeader) - if tcpHeader == nil { + tcpHeader := header.TCP(pkt.TransportHeader().View()) + if len(tcpHeader) < header.TCPMinimumSize { return false } @@ -455,7 +454,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou // Mark the connection as having been used recently so it isn't reaped. conn.lastUsed = time.Now() // Update connection state. - conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) return false } @@ -474,7 +473,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { } // We only track TCP connections. - if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { return } @@ -486,7 +485,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { return } conn := newConn(tid, tid.reply(), manipNone, hook) - conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) ct.insertConn(conn) } diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index c962693f5..944f622fd 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -75,7 +75,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -97,7 +97,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. - b := pkt.Header.Prepend(fwdTestNetHeaderLen) + b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) b[dstAddrOffset] = r.RemoteAddress[0] b[srcAddrOffset] = f.id.LocalAddress[0] b[protocolNumberOffset] = byte(params.Protocol) @@ -144,13 +144,11 @@ func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Add } func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + netHeader, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen) if !ok { return 0, false, false } - pkt.NetworkHeader = netHeader - pkt.Data.TrimFront(fwdTestNetHeaderLen) - return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true + return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true } func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { @@ -290,7 +288,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := fwdTestPacketInfo{ - Pkt: &PacketBuffer{Data: vv}, + Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}), } select { @@ -382,9 +380,9 @@ func TestForwardingWithStaticResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) var p fwdTestPacketInfo @@ -419,9 +417,9 @@ func TestForwardingWithFakeResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) var p fwdTestPacketInfo @@ -450,9 +448,9 @@ func TestForwardingWithNoResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) select { case <-ep2.C: @@ -480,17 +478,17 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // not be forwarded. buf := buffer.NewView(30) buf[dstAddrOffset] = 4 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf = buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) var p fwdTestPacketInfo @@ -500,8 +498,8 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) } // Test that the address resolution happened correctly. @@ -529,9 +527,9 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { for i := 0; i < 2; i++ { buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } for i := 0; i < 2; i++ { @@ -543,8 +541,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) } // Test that the address resolution happened correctly. @@ -575,9 +573,9 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { buf[dstAddrOffset] = 3 // Set the packet sequence number. binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } for i := 0; i < maxPendingPacketsPerResolution; i++ { @@ -589,13 +587,14 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 { + b := PayloadSince(p.Pkt.NetworkHeader()) + if b[dstAddrOffset] != 3 { t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) } - seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes). - if !ok { - t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size()) + if len(b) < fwdTestNetHeaderLen+2 { + t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) } + seqNumBuf := b[fwdTestNetHeaderLen:] // The first 5 packets should not be forwarded so the sequence number should // start with 5. @@ -632,9 +631,9 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // maxPendingResolutions + 7). buf := buffer.NewView(30) buf[dstAddrOffset] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } for i := 0; i < maxPendingResolutions; i++ { @@ -648,8 +647,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - if p.Pkt.NetworkHeader[dstAddrOffset] < 8 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset]) + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) } // Test that the address resolution happened correctly. diff --git a/pkg/tcpip/stack/headertype_string.go b/pkg/tcpip/stack/headertype_string.go new file mode 100644 index 000000000..5efddfaaf --- /dev/null +++ b/pkg/tcpip/stack/headertype_string.go @@ -0,0 +1,39 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type headerType ."; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[linkHeader-0] + _ = x[networkHeader-1] + _ = x[transportHeader-2] + _ = x[numHeaderType-3] +} + +const _headerType_name = "linkHeadernetworkHeadertransportHeadernumHeaderType" + +var _headerType_index = [...]uint8{0, 10, 23, 38, 51} + +func (i headerType) String() string { + if i < 0 || i >= headerType(len(_headerType_index)-1) { + return "headerType(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _headerType_name[_headerType_index[i]:_headerType_index[i+1]] +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 110ba073d..c37da814f 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -394,7 +394,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. - if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) { + if !rule.Filter.match(header.IPv4(pkt.NetworkHeader().View()), hook, nicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index dc88033c7..5f1b2af64 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -99,7 +99,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso } // Drop the packet if network and transport header are not set. - if pkt.NetworkHeader == nil || pkt.TransportHeader == nil { + if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() { return RuleDrop, 0 } @@ -118,17 +118,16 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if // we need to change dest address (for OUTPUT chain) or ports. - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) switch protocol := netHeader.TransportProtocol(); protocol { case header.UDPProtocolNumber: - udpHeader := header.UDP(pkt.TransportHeader) + udpHeader := header.UDP(pkt.TransportHeader().View()) udpHeader.SetDestinationPort(rt.MinPort) // Calculate UDP checksum and set it. if hook == Output { udpHeader.SetChecksum(0) - hdr := &pkt.Header - length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength()) + length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) // Only calculate the checksum if offloading isn't supported. if r.Capabilities()&CapabilityTXChecksumOffload == 0 { diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 5174e639c..93567806b 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -746,12 +746,16 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())) } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) + icmpData.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmpData.NDPPayload()) ns.SetTargetAddress(addr) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(icmpData).ToVectorisedView(), + }) sent := r.Stats().ICMP.V6PacketsSent if err := r.WritePacket(nil, @@ -759,7 +763,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, &PacketBuffer{Header: hdr}, + }, pkt, ); err != nil { sent.Dropped.Increment() return err @@ -1897,12 +1901,16 @@ func (ndp *ndpState) startSolicitingRouters() { } } payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize) - pkt := header.ICMPv6(hdr.Prepend(payloadSize)) - pkt.SetType(header.ICMPv6RouterSolicit) - rs := header.NDPRouterSolicit(pkt.NDPPayload()) + icmpData := header.ICMPv6(buffer.NewView(payloadSize)) + icmpData.SetType(header.ICMPv6RouterSolicit) + rs := header.NDPRouterSolicit(icmpData.NDPPayload()) rs.Options().Serialize(optsSerializer) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(icmpData).ToVectorisedView(), + }) sent := r.Stats().ICMP.V6PacketsSent if err := r.WritePacket(nil, @@ -1910,7 +1918,7 @@ func (ndp *ndpState) startSolicitingRouters() { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, &PacketBuffer{Header: hdr}, + }, pkt, ); err != nil { sent.Dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 5d286ccbc..21bf53010 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -541,7 +541,7 @@ func TestDADResolve(t *testing.T) { // As per RFC 4861 section 4.3, a possible option is the Source Link // Layer option, but this option MUST NOT be included when the source // address of the packet is the unspecified address. - checker.IPv6(t, p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(header.IPv6Any), checker.DstAddr(snmc), checker.TTL(header.NDPHopLimit), @@ -550,8 +550,8 @@ func TestDADResolve(t *testing.T) { checker.NDPNSOptions(nil), )) - if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) } } }) @@ -667,9 +667,10 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate multiple nodes owning or // attempting to own the same address. hdr := test.makeBuf(addr1) - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) + e.InjectInbound(header.IPv6ProtocolNumber, pkt) stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) if got := stat.Value(); got != 1 { @@ -1024,7 +1025,9 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo DstAddr: header.IPv6AllNodesMulticastAddress, }) - return &stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} + return stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) } // raBufWithOpts returns a valid NDP Router Advertisement with options. @@ -5134,16 +5137,15 @@ func TestRouterSolicitation(t *testing.T) { t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) } - checker.IPv6(t, - p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(test.expectedSrcAddr), checker.DstAddr(header.IPv6AllRoutersMulticastAddress), checker.TTL(header.NDPHopLimit), checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), ) - if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) } } waitForNothing := func(timeout time.Duration) { @@ -5288,7 +5290,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { if p.Proto != header.IPv6ProtocolNumber { t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) } - checker.IPv6(t, p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(header.IPv6Any), checker.DstAddr(header.IPv6AllRoutersMulticastAddress), checker.TTL(header.NDPHopLimit), diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index eaaf756cd..2315ea5b9 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1299,7 +1299,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - src, dst := netProto.ParseAddresses(pkt.NetworkHeader) + src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1401,24 +1401,19 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this - // by having lower layers explicity write each header instead of just - // pkt.Header. - // pkt may have set its NetworkHeader and TransportHeader. If we're - // forwarding, we'll have to copy them into pkt.Header. - pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) - if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader))) - } - if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader))) - } + // pkt may have set its header and may not have enough headroom for link-layer + // header for the other link to prepend. Here we create a new packet to + // forward. + fwdPkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()), + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + }) - // WritePacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + // WritePacket takes ownership of fwdPkt, calculate numBytes first. + numBytes := fwdPkt.Size() - if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { + if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return } @@ -1443,34 +1438,31 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - // TransportHeader is nil only when pkt is an ICMP packet or was reassembled + // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. - if pkt.TransportHeader == nil { + if pkt.TransportHeader().View().IsEmpty() { // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a // full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { // ICMP packets may be longer, but until icmp.Parse is implemented, here // we parse it using the minimum size. - transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) - if !ok { + if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - pkt.TransportHeader = transHeader - pkt.Data.TrimFront(len(pkt.TransportHeader)) } else { // This is either a bad packet or was re-assembled from fragments. transProto.Parse(pkt) } } - if len(pkt.TransportHeader) < transProto.MinimumPacketSize() { + if pkt.TransportHeader().View().Size() < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index a70792b50..0870c8d9c 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -311,7 +311,9 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { t.FailNow() } - nic.DeliverNetworkPacket("", "", 0, &PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) + nic.DeliverNetworkPacket("", "", 0, NewPacketBuffer(PacketBufferOptions{ + Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), + })) if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9e871f968..17b8beebb 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -14,16 +14,43 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) +type headerType int + +const ( + linkHeader headerType = iota + networkHeader + transportHeader + numHeaderType +) + +// PacketBufferOptions specifies options for PacketBuffer creation. +type PacketBufferOptions struct { + // ReserveHeaderBytes is the number of bytes to reserve for headers. Total + // number of bytes pushed onto the headers must not exceed this value. + ReserveHeaderBytes int + + // Data is the initial unparsed data for the new packet. If set, it will be + // owned by the new packet. + Data buffer.VectorisedView +} + // A PacketBuffer contains all the data of a network packet. // // As a PacketBuffer traverses up the stack, it may be necessary to pass it to -// multiple endpoints. Clone() should be called in such cases so that -// modifications to the Data field do not affect other copies. +// multiple endpoints. +// +// The whole packet is expected to be a series of bytes in the following order: +// LinkHeader, NetworkHeader, TransportHeader, and Data. Any of them can be +// empty. Use of PacketBuffer in any other order is unsupported. +// +// PacketBuffer must be created with NewPacketBuffer. type PacketBuffer struct { _ sync.NoCopy @@ -31,36 +58,27 @@ type PacketBuffer struct { // PacketBuffers. PacketBufferEntry - // Data holds the payload of the packet. For inbound packets, it also - // holds the headers, which are consumed as the packet moves up the - // stack. Headers are guaranteed not to be split across views. + // Data holds the payload of the packet. + // + // For inbound packets, Data is initially the whole packet. Then gets moved to + // headers via PacketHeader.Consume, when the packet is being parsed. // - // The bytes backing Data are immutable, but Data itself may be trimmed - // or otherwise modified. + // For outbound packets, Data is the innermost layer, defined by the protocol. + // Headers are pushed in front of it via PacketHeader.Push. + // + // The bytes backing Data are immutable, a.k.a. users shouldn't write to its + // backing storage. Data buffer.VectorisedView - // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. Note that forwarded - // packets don't populate Headers on their way out -- their headers and - // payload are never parsed out and remain in Data. - // - // TODO(gvisor.dev/issue/170): Forwarded packets don't currently - // populate Header, but should. This will be doable once early parsing - // (https://github.com/google/gvisor/pull/1995) is supported. - Header buffer.Prependable + // headers stores metadata about each header. + headers [numHeaderType]headerInfo - // These fields are used by both inbound and outbound packets. They - // typically overlap with the Data and Header fields. + // header is the internal storage for outbound packets. Headers will be pushed + // (prepended) on this storage as the packet is being constructed. // - // The bytes backing these views are immutable. Each field may be nil - // if either it has not been set yet or no such header exists (e.g. - // packets sent via loopback may not have a link header). - // - // These fields may be Views into other slices (either Data or Header). - // SR dosen't support this, so deep copies are necessary in some cases. - LinkHeader buffer.View - NetworkHeader buffer.View - TransportHeader buffer.View + // TODO(gvisor.dev/issue/2404): Switch to an implementation that header and + // data are held in the same underlying buffer storage. + header buffer.Prependable // NetworkProtocol is only valid when NetworkHeader is set. // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol @@ -89,20 +107,137 @@ type PacketBuffer struct { PktType tcpip.PacketType } -// Clone makes a copy of pk. It clones the Data field, which creates a new -// VectorisedView but does not deep copy the underlying bytes. -// -// Clone also does not deep copy any of its other fields. +// NewPacketBuffer creates a new PacketBuffer with opts. +func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { + pk := &PacketBuffer{ + Data: opts.Data, + } + if opts.ReserveHeaderBytes != 0 { + pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes) + } + return pk +} + +// ReservedHeaderBytes returns the number of bytes initially reserved for +// headers. +func (pk *PacketBuffer) ReservedHeaderBytes() int { + return pk.header.UsedLength() + pk.header.AvailableLength() +} + +// AvailableHeaderBytes returns the number of bytes currently available for +// headers. This is relevant to PacketHeader.Push method only. +func (pk *PacketBuffer) AvailableHeaderBytes() int { + return pk.header.AvailableLength() +} + +// LinkHeader returns the handle to link-layer header. +func (pk *PacketBuffer) LinkHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: linkHeader, + } +} + +// NetworkHeader returns the handle to network-layer header. +func (pk *PacketBuffer) NetworkHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: networkHeader, + } +} + +// TransportHeader returns the handle to transport-layer header. +func (pk *PacketBuffer) TransportHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: transportHeader, + } +} + +// HeaderSize returns the total size of all headers in bytes. +func (pk *PacketBuffer) HeaderSize() int { + // Note for inbound packets (Consume called), headers are not stored in + // pk.header. Thus, calculation of size of each header is needed. + var size int + for i := range pk.headers { + size += len(pk.headers[i].buf) + } + return size +} + +// Size returns the size of packet in bytes. +func (pk *PacketBuffer) Size() int { + return pk.HeaderSize() + pk.Data.Size() +} + +// Views returns the underlying storage of the whole packet. +func (pk *PacketBuffer) Views() []buffer.View { + // Optimization for outbound packets that headers are in pk.header. + useHeader := true + for i := range pk.headers { + if !canUseHeader(&pk.headers[i]) { + useHeader = false + break + } + } + + dataViews := pk.Data.Views() + + var vs []buffer.View + if useHeader { + vs = make([]buffer.View, 0, 1+len(dataViews)) + vs = append(vs, pk.header.View()) + } else { + vs = make([]buffer.View, 0, len(pk.headers)+len(dataViews)) + for i := range pk.headers { + if v := pk.headers[i].buf; len(v) > 0 { + vs = append(vs, v) + } + } + } + return append(vs, dataViews...) +} + +func canUseHeader(h *headerInfo) bool { + // h.offset will be negative if the header was pushed in to prependable + // portion, or doesn't matter when it's empty. + return len(h.buf) == 0 || h.offset < 0 +} + +func (pk *PacketBuffer) push(typ headerType, size int) buffer.View { + h := &pk.headers[typ] + if h.buf != nil { + panic(fmt.Sprintf("push must not be called twice: type %s", typ)) + } + h.buf = buffer.View(pk.header.Prepend(size)) + h.offset = -pk.header.UsedLength() + return h.buf +} + +func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consumed bool) { + h := &pk.headers[typ] + if h.buf != nil { + panic(fmt.Sprintf("consume must not be called twice: type %s", typ)) + } + v, ok := pk.Data.PullUp(size) + if !ok { + return + } + pk.Data.TrimFront(size) + h.buf = v + return h.buf, true +} + +// Clone makes a shallow copy of pk. // -// FIXME(b/153685824): Data gets copied but not other header references. +// Clone should be called in such cases so that no modifications is done to +// underlying packet payload. func (pk *PacketBuffer) Clone() *PacketBuffer { - return &PacketBuffer{ + newPk := &PacketBuffer{ PacketBufferEntry: pk.PacketBufferEntry, Data: pk.Data.Clone(nil), - Header: pk.Header, - LinkHeader: pk.LinkHeader, - NetworkHeader: pk.NetworkHeader, - TransportHeader: pk.TransportHeader, + headers: pk.headers, + header: pk.header, Hash: pk.Hash, Owner: pk.Owner, EgressRoute: pk.EgressRoute, @@ -110,4 +245,55 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { NetworkProtocolNumber: pk.NetworkProtocolNumber, NatDone: pk.NatDone, } + return newPk +} + +// headerInfo stores metadata about a header in a packet. +type headerInfo struct { + // buf is the memorized slice for both prepended and consumed header. + // When header is prepended, buf serves as memorized value, which is a slice + // of pk.header. When header is consumed, buf is the slice pulled out from + // pk.Data, which is the only place to hold this header. + buf buffer.View + + // offset will be a negative number denoting the offset where this header is + // from the end of pk.header, if it is prepended. Otherwise, zero. + offset int +} + +// PacketHeader is a handle object to a header in the underlying packet. +type PacketHeader struct { + pk *PacketBuffer + typ headerType +} + +// View returns the underlying storage of h. +func (h PacketHeader) View() buffer.View { + return h.pk.headers[h.typ].buf +} + +// Push pushes size bytes in the front of its residing packet, and returns the +// backing storage. Callers may only call one of Push or Consume once on each +// header in the lifetime of the underlying packet. +func (h PacketHeader) Push(size int) buffer.View { + return h.pk.push(h.typ, size) +} + +// Consume moves the first size bytes of the unparsed data portion in the packet +// to h, and returns the backing storage. In the case of data is shorter than +// size, consumed will be false, and the state of h will not be affected. +// Callers may only call one of Push or Consume once on each header in the +// lifetime of the underlying packet. +func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { + return h.pk.consume(h.typ, size) +} + +// PayloadSince returns packet payload starting from and including a particular +// header. This method isn't optimized and should be used in test only. +func PayloadSince(h PacketHeader) buffer.View { + var v buffer.View + for _, hinfo := range h.pk.headers[h.typ:] { + v = append(v, hinfo.buf...) + } + return append(v, h.pk.Data.ToView()...) } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go new file mode 100644 index 000000000..c6fa8da5f --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -0,0 +1,397 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "bytes" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func TestPacketHeaderPush(t *testing.T) { + for _, test := range []struct { + name string + reserved int + link []byte + network []byte + transport []byte + data []byte + }{ + { + name: "construct empty packet", + }, + { + name: "construct link header only packet", + reserved: 60, + link: makeView(10), + }, + { + name: "construct link and network header only packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + }, + { + name: "construct header only packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + transport: makeView(30), + }, + { + name: "construct data only packet", + data: makeView(40), + }, + { + name: "construct L3 packet", + reserved: 60, + network: makeView(20), + transport: makeView(30), + data: makeView(40), + }, + { + name: "construct L2 packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + transport: makeView(30), + data: makeView(40), + }, + } { + t.Run(test.name, func(t *testing.T) { + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: test.reserved, + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), + }) + + allHdrSize := len(test.link) + len(test.network) + len(test.transport) + + // Check the initial values for packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + ReserveHeaderBytes: test.reserved, + Data: buffer.View(test.data).ToVectorisedView(), + }) + + // Push headers. + if v := test.transport; len(v) > 0 { + copy(pk.TransportHeader().Push(len(v)), v) + } + if v := test.network; len(v) > 0 { + copy(pk.NetworkHeader().Push(len(v)), v) + } + if v := test.link; len(v) > 0 { + copy(pk.LinkHeader().Push(len(v)), v) + } + + // Check the after values for packet. + if got, want := pk.ReservedHeaderBytes(), test.reserved; got != want { + t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), test.reserved-allHdrSize; got != want { + t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), allHdrSize; got != want { + t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) + } + if got, want := pk.Size(), allHdrSize+len(test.data); got != want { + t.Errorf("After pk.Size() = %d, want %d", got, want) + } + checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), test.data) + checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), + concatViews(test.link, test.network, test.transport, test.data)) + // Check the after values for each header. + checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link) + checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network) + checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport) + // Check the after values for PayloadSince. + checkViewEqual(t, "After PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(test.link, test.network, test.transport, test.data)) + checkViewEqual(t, "After PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(test.network, test.transport, test.data)) + checkViewEqual(t, "After PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(test.transport, test.data)) + }) + } +} + +func TestPacketHeaderConsume(t *testing.T) { + for _, test := range []struct { + name string + data []byte + link int + network int + transport int + }{ + { + name: "parse L2 packet", + data: concatViews(makeView(10), makeView(20), makeView(30), makeView(40)), + link: 10, + network: 20, + transport: 30, + }, + { + name: "parse L3 packet", + data: concatViews(makeView(20), makeView(30), makeView(40)), + network: 20, + transport: 30, + }, + } { + t.Run(test.name, func(t *testing.T) { + pk := NewPacketBuffer(PacketBufferOptions{ + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), + }) + + // Check the initial values for packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + Data: buffer.View(test.data).ToVectorisedView(), + }) + + // Consume headers. + if size := test.link; size > 0 { + if _, ok := pk.LinkHeader().Consume(size); !ok { + t.Fatalf("pk.LinkHeader().Consume() = false, want true") + } + } + if size := test.network; size > 0 { + if _, ok := pk.NetworkHeader().Consume(size); !ok { + t.Fatalf("pk.NetworkHeader().Consume() = false, want true") + } + } + if size := test.transport; size > 0 { + if _, ok := pk.TransportHeader().Consume(size); !ok { + t.Fatalf("pk.TransportHeader().Consume() = false, want true") + } + } + + allHdrSize := test.link + test.network + test.transport + + // Check the after values for packet. + if got, want := pk.ReservedHeaderBytes(), 0; got != want { + t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), 0; got != want { + t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), allHdrSize; got != want { + t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) + } + if got, want := pk.Size(), len(test.data); got != want { + t.Errorf("After pk.Size() = %d, want %d", got, want) + } + // After state of pk. + var ( + link = test.data[:test.link] + network = test.data[test.link:][:test.network] + transport = test.data[test.link+test.network:][:test.transport] + payload = test.data[allHdrSize:] + ) + checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), payload) + checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) + // Check the after values for each header. + checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) + checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network) + checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport) + // Check the after values for PayloadSince. + checkViewEqual(t, "After PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(link, network, transport, payload)) + checkViewEqual(t, "After PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(network, transport, payload)) + checkViewEqual(t, "After PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(transport, payload)) + }) + } +} + +func TestPacketHeaderConsumeDataTooShort(t *testing.T) { + data := makeView(10) + + pk := NewPacketBuffer(PacketBufferOptions{ + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(data).ToVectorisedView(), + }) + + // Consume should fail if pkt.Data is too short. + if _, ok := pk.LinkHeader().Consume(11); ok { + t.Fatalf("pk.LinkHeader().Consume() = _, true; want _, false") + } + if _, ok := pk.NetworkHeader().Consume(11); ok { + t.Fatalf("pk.NetworkHeader().Consume() = _, true; want _, false") + } + if _, ok := pk.TransportHeader().Consume(11); ok { + t.Fatalf("pk.TransportHeader().Consume() = _, true; want _, false") + } + + // Check packet should look the same as initial packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + Data: buffer.View(data).ToVectorisedView(), + }) +} + +func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: headerSize * int(numHeaderType), + }) + + for _, h := range []PacketHeader{ + pk.TransportHeader(), + pk.NetworkHeader(), + pk.LinkHeader(), + } { + t.Run("PushedTwice/"+h.typ.String(), func(t *testing.T) { + h.Push(headerSize) + + defer func() { recover() }() + h.Push(headerSize) + t.Fatal("Second push should have panicked") + }) + } +} + +func TestPacketHeaderConsumeCalledAtMostOnce(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), + }) + + for _, h := range []PacketHeader{ + pk.LinkHeader(), + pk.NetworkHeader(), + pk.TransportHeader(), + } { + t.Run("ConsumedTwice/"+h.typ.String(), func(t *testing.T) { + if _, ok := h.Consume(headerSize); !ok { + t.Fatal("First consume should succeed") + } + + defer func() { recover() }() + h.Consume(headerSize) + t.Fatal("Second consume should have panicked") + }) + } +} + +func TestPacketHeaderPushThenConsumePanics(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: headerSize * int(numHeaderType), + }) + + for _, h := range []PacketHeader{ + pk.TransportHeader(), + pk.NetworkHeader(), + pk.LinkHeader(), + } { + t.Run(h.typ.String(), func(t *testing.T) { + h.Push(headerSize) + + defer func() { recover() }() + h.Consume(headerSize) + t.Fatal("Consume should have panicked") + }) + } +} + +func TestPacketHeaderConsumeThenPushPanics(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), + }) + + for _, h := range []PacketHeader{ + pk.LinkHeader(), + pk.NetworkHeader(), + pk.TransportHeader(), + } { + t.Run(h.typ.String(), func(t *testing.T) { + h.Consume(headerSize) + + defer func() { recover() }() + h.Push(headerSize) + t.Fatal("Push should have panicked") + }) + } +} + +func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { + t.Helper() + reserved := opts.ReserveHeaderBytes + if got, want := pk.ReservedHeaderBytes(), reserved; got != want { + t.Errorf("Initial pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), reserved; got != want { + t.Errorf("Initial pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), 0; got != want { + t.Errorf("Initial pk.HeaderSize() = %d, want %d", got, want) + } + data := opts.Data.ToView() + if got, want := pk.Size(), len(data); got != want { + t.Errorf("Initial pk.Size() = %d, want %d", got, want) + } + checkViewEqual(t, "Initial pk.Data.Views()", concatViews(pk.Data.Views()...), data) + checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) + // Check the initial values for each header. + checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) + checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil) + checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil) + // Check the initial valies for PayloadSince. + checkViewEqual(t, "Initial PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), data) + checkViewEqual(t, "Initial PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), data) + checkViewEqual(t, "Initial PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), data) +} + +func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) { + t.Helper() + checkViewEqual(t, name+".View()", h.View(), want) +} + +func checkViewEqual(t *testing.T, what string, got, want buffer.View) { + t.Helper() + if !bytes.Equal(got, want) { + t.Errorf("%s = %x, want %x", what, got, want) + } +} + +func makeView(size int) buffer.View { + b := byte(size) + return bytes.Repeat([]byte{b}, size) +} + +func concatViews(views ...buffer.View) buffer.View { + var all buffer.View + for _, v := range views { + all = append(all, v...) + } + return all +} diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 9ce0a2c22..e267bebb0 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -173,7 +173,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf } // WritePacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + numBytes := pkt.Size() err := r.ref.ep.WritePacket(r, gso, params, pkt) if err != nil { @@ -203,8 +203,7 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead writtenBytes := 0 for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { - writtenBytes += pb.Header.UsedLength() - writtenBytes += pb.Data.Size() + writtenBytes += pb.Size() } r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index fe1c1b8a4..0273b3c63 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -102,7 +102,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Handle control packets. - if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) { + if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) { nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) if !ok { return @@ -118,7 +118,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -143,10 +143,10 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params // Add the protocol's header to the packet and send it to the link // endpoint. - pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen) - pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0] - pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0] - pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol) + hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen) + hdr[dstAddrOffset] = r.RemoteAddress[0] + hdr[srcAddrOffset] = f.id.LocalAddress[0] + hdr[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { f.HandlePacket(r, pkt) @@ -249,12 +249,10 @@ func (*fakeNetworkProtocol) Wait() {} // Parse implements TransportProtocol.Parse. func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen) + hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen) if !ok { return 0, false, false } - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(fakeNetHeaderLen) return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } @@ -315,9 +313,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[dstAddrOffset] = 3 - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -327,9 +325,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[dstAddrOffset] = 1 - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -339,9 +337,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[dstAddrOffset] = 2 - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -350,9 +348,9 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber-1, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -362,9 +360,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -383,11 +381,10 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro } func send(r stack.Route, payload buffer.View) *tcpip.Error { - hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + })) } func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { @@ -442,9 +439,9 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -2285,9 +2282,9 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -2367,9 +2364,9 @@ func TestNICForwarding(t *testing.T) { // Send a packet to dstAddr. buf := buffer.NewView(30) buf[dstAddrOffset] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) pkt, ok := ep2.Read() if !ok { @@ -2377,8 +2374,8 @@ func TestNICForwarding(t *testing.T) { } // Test that the link's MaxHeaderLength is honoured. - if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want { - t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want) + if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want { + t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want) } // Test that forwarding increments Tx stats correctly. diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 73dada928..1339edc2d 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -128,11 +128,10 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), }) + c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt) } func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { @@ -166,11 +165,10 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), }) + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt) } func TestTransportDemuxerRegister(t *testing.T) { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 7e8b84867..6c6e44468 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -84,16 +84,16 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return 0, nil, tcpip.ErrNoRoute } - hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()) + fakeTransHeaderLen) - hdr.Prepend(fakeTransHeaderLen) v, err := p.FullPayload() if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.View(v).ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen, + Data: buffer.View(v).ToVectorisedView(), + }) + _ = pkt.TransportHeader().Push(fakeTransHeaderLen) + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { return 0, nil, err } @@ -328,13 +328,8 @@ func (*fakeTransportProtocol) Wait() {} // Parse implements TransportProtocol.Parse. func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { - hdr, ok := pkt.Data.PullUp(fakeTransHeaderLen) - if !ok { - return false - } - pkt.TransportHeader = hdr - pkt.Data.TrimFront(fakeTransHeaderLen) - return true + _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen) + return ok } func fakeTransFactory() stack.TransportProtocol { @@ -382,9 +377,9 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -393,9 +388,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -404,9 +399,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.packetCount != 1 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) } @@ -459,9 +454,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -470,9 +465,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -481,9 +476,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.controlCount != 1 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) } @@ -636,9 +631,9 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: req.ToVectorisedView(), - }) + })) aep, _, err := ep.Accept() if err != nil || aep == nil { @@ -655,10 +650,11 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.NetworkHeader[0]; dst != 3 { + nh := stack.PayloadSince(p.Pkt.NetworkHeader()) + if dst := nh[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.NetworkHeader[1]; src != 1 { + if src := nh[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 0ff3a2b89..9f0dd4d6d 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -80,9 +80,9 @@ func TestPingMulticastBroadcast(t *testing.T) { DstAddr: dst, }) - e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) } rxIPv6ICMP := func(e *channel.Endpoint, dst tcpip.Address) { @@ -102,9 +102,9 @@ func TestPingMulticastBroadcast(t *testing.T) { DstAddr: dst, }) - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) } tests := []struct { @@ -204,7 +204,7 @@ func TestPingMulticastBroadcast(t *testing.T) { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) } - src, dst := proto.ParseAddresses(pkt.Pkt.NetworkHeader) + src, dst := proto.ParseAddresses(pkt.Pkt.NetworkHeader().View()) if src != expectedSrc { t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) } @@ -252,9 +252,9 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { DstAddr: dst, }) - e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) } rxIPv6UDP := func(e *channel.Endpoint, dst tcpip.Address) { @@ -280,9 +280,9 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { DstAddr: dst, }) - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) } tests := []struct { diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 4612be4e7..bd6f49eb8 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -430,9 +430,12 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi return tcpip.ErrInvalidEndpointState } - hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()), + }) + pkt.Owner = owner - icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize)) copy(icmpv4, data) // Set the ident to the user-specified port. Sequence number should // already be set by the user. @@ -447,15 +450,12 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) + pkt.Data = data.ToVectorisedView() + if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: data.ToVectorisedView(), - TransportHeader: buffer.View(icmpv4), - Owner: owner, - }) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) } func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { @@ -463,9 +463,11 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err return tcpip.ErrInvalidEndpointState } - hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()), + }) - icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize)) copy(icmpv6, data) // Set the ident. Sequence number is provided by the user. icmpv6.SetIdent(ident) @@ -477,15 +479,12 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err dataVV := data.ToVectorisedView() icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV)) + pkt.Data = dataVV if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: dataVV, - TransportHeader: buffer.View(icmpv6), - }) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) } // checkV4MappedLocked determines the effective network protocol and converts @@ -748,14 +747,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(pkt.TransportHeader) + h := header.ICMPv4(pkt.TransportHeader().View()) + // TODO(b/129292233): Determine if len(h) check is still needed after early + // parsing. if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(pkt.TransportHeader) + h := header.ICMPv6(pkt.TransportHeader().View()) + // TODO(b/129292233): Determine if len(h) check is still needed after early + // parsing. if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -791,7 +794,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } // ICMP socket's data includes ICMP header. - packet.data = pkt.TransportHeader.ToVectorisedView() + packet.data = pkt.TransportHeader().View().ToVectorisedView() packet.data.Append(pkt.Data) e.rcvList.PushBack(packet) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index df478115d..1b03ad6bb 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -433,9 +433,9 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, // Push new packet into receive list and increment the buffer size. var packet packet // TODO(gvisor.dev/issue/173): Return network protocol. - if len(pkt.LinkHeader) > 0 { + if !pkt.LinkHeader().View().IsEmpty() { // Get info directly from the ethernet header. - hdr := header.Ethernet(pkt.LinkHeader) + hdr := header.Ethernet(pkt.LinkHeader().View()) packet.senderAddr = tcpip.FullAddress{ NIC: nicID, Addr: tcpip.Address(hdr.SourceAddress()), @@ -458,9 +458,14 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, case tcpip.PacketHost: packet.data = pkt.Data case tcpip.PacketOutgoing: - // Strip Link Header from the Header. - pkt.Header = buffer.NewPrependableFromView(pkt.Header.View()[len(pkt.LinkHeader):]) - combinedVV := pkt.Header.View().ToVectorisedView() + // Strip Link Header. + var combinedVV buffer.VectorisedView + if v := pkt.NetworkHeader().View(); !v.IsEmpty() { + combinedVV.AppendView(v) + } + if v := pkt.TransportHeader().View(); !v.IsEmpty() { + combinedVV.AppendView(v) + } combinedVV.Append(pkt.Data) packet.data = combinedVV default: @@ -471,9 +476,8 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, // Raw packets need their ethernet headers prepended before // queueing. var linkHeader buffer.View - var combinedVV buffer.VectorisedView if pkt.PktType != tcpip.PacketOutgoing { - if len(pkt.LinkHeader) == 0 { + if pkt.LinkHeader().View().IsEmpty() { // We weren't provided with an actual ethernet header, // so fake one. ethFields := header.EthernetFields{ @@ -485,19 +489,14 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, fakeHeader.Encode(ðFields) linkHeader = buffer.View(fakeHeader) } else { - linkHeader = append(buffer.View(nil), pkt.LinkHeader...) + linkHeader = append(buffer.View(nil), pkt.LinkHeader().View()...) } - combinedVV = linkHeader.ToVectorisedView() - } - if pkt.PktType == tcpip.PacketOutgoing { - // For outgoing packets the Link, Network and Transport - // headers are in the pkt.Header fields normally unless - // a Raw socket is in use in which case pkt.Header could - // be nil. - combinedVV.AppendView(pkt.Header.View()) + combinedVV := linkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + } else { + packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views()) } - combinedVV.Append(pkt.Data) - packet.data = combinedVV } packet.timestampNS = ep.stack.Clock().NowNanoseconds() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index f85a68554..edc2b5b61 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -352,18 +352,23 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } if e.hdrIncluded { - if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(payloadBytes).ToVectorisedView(), - }); err != nil { + }) + if err := route.WriteHeaderIncludedPacket(pkt); err != nil { return 0, nil, err } } else { - hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.View(payloadBytes).ToVectorisedView(), - Owner: e.owner, - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(route.MaxHeaderLength()), + Data: buffer.View(payloadBytes).ToVectorisedView(), + }) + pkt.Owner = e.owner + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: e.TransProto, + TTL: route.DefaultTTL(), + TOS: stack.DefaultTOS, + }, pkt); err != nil { return 0, nil, err } } @@ -691,12 +696,13 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { // slice. Save/restore doesn't support overlapping slices and will fail. var combinedVV buffer.VectorisedView if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber { - headers := make(buffer.View, 0, len(pkt.NetworkHeader)+len(pkt.TransportHeader)) - headers = append(headers, pkt.NetworkHeader...) - headers = append(headers, pkt.TransportHeader...) + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + headers := make(buffer.View, 0, len(network)+len(transport)) + headers = append(headers, network...) + headers = append(headers, transport...) combinedVV = headers.ToVectorisedView() } else { - combinedVV = append(buffer.View(nil), pkt.TransportHeader...).ToVectorisedView() + combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView() } combinedVV.Append(pkt.Data) packet.data = combinedVV diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 46702906b..290172ac9 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -746,11 +746,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { optLen := len(tf.opts) - hdr := &pkt.Header - packetSize := pkt.Data.Size() - // Initialize the header. - tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) - pkt.TransportHeader = buffer.View(tcp) + tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen)) tcp.Encode(&header.TCPFields{ SrcPort: tf.id.LocalPort, DstPort: tf.id.RemotePort, @@ -762,8 +758,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta }) copy(tcp[header.TCPMinimumSize:], tf.opts) - length := uint16(hdr.UsedLength() + packetSize) - xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) + xsum := r.PseudoHeaderChecksum(ProtocolNumber, uint16(pkt.Size())) // Only calculate the checksum if offloading isn't supported. if gso != nil && gso.NeedsCsum { // This is called CHECKSUM_PARTIAL in the Linux kernel. We @@ -801,17 +796,18 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso packetSize = size } size -= packetSize - var pkt stack.PacketBuffer - pkt.Header = buffer.NewPrependable(hdrSize) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: hdrSize, + }) pkt.Hash = tf.txHash pkt.Owner = owner pkt.EgressRoute = r pkt.GSOOptions = gso pkt.NetworkProtocolNumber = r.NetworkProtocolNumber() data.ReadToVV(&pkt.Data, packetSize) - buildTCPHdr(r, tf, &pkt, gso) + buildTCPHdr(r, tf, pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) - pkts.PushBack(&pkt) + pkts.PushBack(pkt) } if tf.ttl == 0 { @@ -837,12 +833,12 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac return sendTCPBatch(r, tf, data, gso, owner) } - pkt := &stack.PacketBuffer{ - Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), - Data: data, - Hash: tf.txHash, - Owner: owner, - } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen, + Data: data, + }) + pkt.Hash = tf.txHash + pkt.Owner = owner buildTCPHdr(r, tf, pkt, gso) if tf.ttl == 0 { diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 49a673b42..c5afa2680 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -21,7 +21,6 @@ package tcp import ( - "fmt" "runtime" "strings" "time" @@ -547,22 +546,22 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter { // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + // TCP header is variable length, peek at it first. + hdrLen := header.TCPMinimumSize + hdr, ok := pkt.Data.PullUp(hdrLen) if !ok { return false } // If the header has options, pull those up as well. if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { - hdr, ok = pkt.Data.PullUp(offset) - if !ok { - panic(fmt.Sprintf("There should be at least %d bytes in pkt.Data.", offset)) - } + // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of + // packets. + hdrLen = offset } - pkt.TransportHeader = hdr - pkt.Data.TrimFront(len(hdr)) - return true + _, ok = pkt.TransportHeader().Consume(hdrLen) + return ok } // NewProtocol returns a TCP transport protocol. diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index bb60dc29d..94307d31a 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -68,7 +68,7 @@ func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketB route: r.Clone(), } s.data = pkt.Data.Clone(s.views[:]) - s.hdr = header.TCP(pkt.TransportHeader) + s.hdr = header.TCP(pkt.TransportHeader().View()) s.rcvdTime = time.Now() return s } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 37e7767d6..927bc71e0 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -257,8 +257,8 @@ func (c *Context) GetPacket() []byte { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) @@ -284,8 +284,8 @@ func (c *Context) GetPacketNonBlocking() []byte { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) return b @@ -318,9 +318,10 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // BuildSegment builds a TCP segment based on the given Headers and payload. @@ -374,26 +375,29 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: s, }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: c.BuildSegment(payload, h), }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // SendPacketWithAddrs builds and sends a TCP segment(with the provided payload // & TCPheaders) in an IPv4 packet via the link layer endpoint using the // provided source and destination IPv4 addresses. func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: c.BuildSegmentWithAddrs(payload, h, src, dst), }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // SendAck sends an ACK packet. @@ -514,9 +518,8 @@ func (c *Context) GetV6Packet() []byte { if p.Proto != ipv6.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) } - b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) - copy(b, p.Pkt.Header.View()) - copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) return b @@ -566,9 +569,10 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), }) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, pkt) } // CreateConnected creates a connected TCP endpoint. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 4a2b6c03a..73608783c 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -986,13 +986,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error { - // Allocate a buffer for the UDP header. - hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), + Data: data, + }) + pkt.Owner = owner - // Initialize the header. - udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + // Initialize the UDP header. + udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) - length := uint16(hdr.UsedLength() + data.Size()) + length := uint16(pkt.Size()) udp.Encode(&header.UDPFields{ SrcPort: localPort, DstPort: remotePort, @@ -1019,12 +1022,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u Protocol: ProtocolNumber, TTL: ttl, TOS: tos, - }, &stack.PacketBuffer{ - Header: hdr, - Data: data, - TransportHeader: buffer.View(udp), - Owner: owner, - }); err != nil { + }, pkt); err != nil { r.Stats().UDP.PacketSendErrors.Increment() return err } @@ -1372,7 +1370,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(pkt.TransportHeader) + hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() @@ -1443,9 +1441,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Save any useful information from the network header to the packet. switch r.NetProto { case header.IPv4ProtocolNumber: - packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS() + packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS() case header.IPv6ProtocolNumber: - packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS() + packet.tos, _ = header.IPv6(pkt.NetworkHeader().View()).TOS() } // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 0e7464e3a..63d4bed7c 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -82,7 +82,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - hdr := header.UDP(pkt.TransportHeader) + hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() @@ -130,7 +130,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen - payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() + payloadLen := pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size() + pkt.Data.Size() if payloadLen > available { payloadLen = available } @@ -139,22 +139,21 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans // For example, a raw or packet socket may use what UDP // considers an unreachable destination. Thus we deep copy pkt // to prevent multiple ownership and SR errors. - newHeader := append(buffer.View(nil), pkt.NetworkHeader...) - newHeader = append(newHeader, pkt.TransportHeader...) + newHeader := append(buffer.View(nil), pkt.NetworkHeader().View()...) + newHeader = append(newHeader, pkt.TransportHeader().View()...) payload := newHeader.ToVectorisedView() payload.AppendView(pkt.Data.ToView()) payload.CapLength(payloadLen) - hdr := buffer.NewPrependable(headerLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4DstUnreachable) - pkt.SetCode(header.ICMPv4PortUnreachable) - pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - TransportHeader: buffer.View(pkt), - Data: payload, + icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: headerLen, + Data: payload, }) + icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4PortUnreachable) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt) case header.IPv6AddressSize: if !r.Stack().AllowICMPMessage() { @@ -175,24 +174,24 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize available := int(mtu) - headerLen - payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + payloadLen := len(network) + len(transport) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(pkt.NetworkHeader)+len(pkt.TransportHeader), []buffer.View{pkt.NetworkHeader, pkt.TransportHeader}) + payload := buffer.NewVectorisedView(len(network)+len(transport), []buffer.View{network, transport}) payload.Append(pkt.Data) payload.CapLength(payloadLen) - hdr := buffer.NewPrependable(headerLen) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize)) - pkt.SetType(header.ICMPv6DstUnreachable) - pkt.SetCode(header.ICMPv6PortUnreachable) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - TransportHeader: buffer.View(pkt), - Data: payload, + icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: headerLen, + Data: payload, }) + icmpHdr := header.ICMPv6(icmpPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6PortUnreachable) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, icmpPkt.Data)) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt) } return true } @@ -215,14 +214,8 @@ func (*protocol) Wait() {} // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - h, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { - // Packet is too small - return false - } - pkt.TransportHeader = h - pkt.Data.TrimFront(header.UDPMinimumSize) - return true + _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize) + return ok } // NewProtocol returns a UDP transport protocol. diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 1a32622ca..71776d6db 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -388,8 +388,8 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() h := flow.header4Tuple(outgoing) checkers = append( @@ -410,14 +410,14 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { h := flow.header4Tuple(incoming) if flow.isV4() { buf := c.buildV4Packet(payload, &h) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } else { buf := c.buildV6Packet(payload, &h) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } } @@ -804,9 +804,9 @@ func TestV4ReadSelfSource(t *testing.T) { h.srcAddr = h.dstAddr buf := c.buildV4Packet(payload, &h) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) @@ -1766,9 +1766,8 @@ func TestV4UnknownDestination(t *testing.T) { return } - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + pkt := vv.ToView() if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } @@ -1844,9 +1843,8 @@ func TestV6UnknownDestination(t *testing.T) { return } - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + pkt := vv.ToView() if got, want := len(pkt), header.IPv6MinimumMTU; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } @@ -1897,9 +1895,9 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) { u := header.UDP(buf[header.IPv6MinimumSize:]) u.SetLength(u.Length() + 1) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { @@ -1952,9 +1950,9 @@ func TestShortHeader(t *testing.T) { copy(buf[header.IPv6MinimumSize:], udpHdr) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want { t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want) @@ -1986,9 +1984,9 @@ func TestIncrementChecksumErrorsV4(t *testing.T) { } } - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2019,9 +2017,9 @@ func TestIncrementChecksumErrorsV6(t *testing.T) { u := header.UDP(buf[header.IPv6MinimumSize:]) u.SetChecksum(u.Checksum() + 1) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2049,9 +2047,9 @@ func TestPayloadModifiedV4(t *testing.T) { buf := c.buildV4Packet(payload, &h) // Modify the payload so that the checksum value in the UDP header will be incorrect. buf[len(buf)-1]++ - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2079,9 +2077,9 @@ func TestPayloadModifiedV6(t *testing.T) { buf := c.buildV6Packet(payload, &h) // Modify the payload so that the checksum value in the UDP header will be incorrect. buf[len(buf)-1]++ - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2110,9 +2108,9 @@ func TestChecksumZeroV4(t *testing.T) { // Set the checksum field in the UDP header to zero. u := header.UDP(buf[header.IPv4MinimumSize:]) u.SetChecksum(0) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 0 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2141,9 +2139,9 @@ func TestChecksumZeroV6(t *testing.T) { // Set the checksum field in the UDP header to zero. u := header.UDP(buf[header.IPv6MinimumSize:]) u.SetChecksum(0) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { -- cgit v1.2.3