diff options
Diffstat (limited to 'pkg')
99 files changed, 2744 insertions, 1225 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 05ca5342f..b5c5cc20b 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -41,6 +41,7 @@ go_library( "mm.go", "netdevice.go", "netfilter.go", + "netfilter_ipv6.go", "netlink.go", "netlink_route.go", "poll.go", diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go index 5c6ffe4a3..7e30483ee 100644 --- a/pkg/abi/linux/fuse.go +++ b/pkg/abi/linux/fuse.go @@ -246,3 +246,58 @@ type FUSEInitOut struct { _ [8]uint32 } + +// FUSEGetAttrIn is the request sent by the kernel to the daemon, +// to get the attribute of a inode. +// +// +marshal +type FUSEGetAttrIn struct { + // GetAttrFlags specifies whether getattr request is sent with a nodeid or + // with a file handle. + GetAttrFlags uint32 + + _ uint32 + + // Fh is the file handler when GetAttrFlags has FUSE_GETATTR_FH bit. If + // used, the operation is analogous to fstat(2). + Fh uint64 +} + +// FUSEAttr is the struct used in the reponse FUSEGetAttrOut. +// +// +marshal +type FUSEAttr struct { + Ino uint64 + Size uint64 + Blocks uint64 + Atime uint64 + Mtime uint64 + Ctime uint64 + AtimeNsec uint32 + MtimeNsec uint32 + CtimeNsec uint32 + Mode uint32 + Nlink uint32 + UID uint32 + GID uint32 + Rdev uint32 + BlkSize uint32 + _ uint32 +} + +// FUSEGetAttrOut is the reply sent by the daemon to the kernel +// for FUSEGetAttrIn. +// +// +marshal +type FUSEGetAttrOut struct { + // AttrValid and AttrValidNsec describe the attribute cache duration + AttrValid uint64 + + // AttrValidNsec is the nanosecond part of the attribute cache duration + AttrValidNsec uint32 + + _ uint32 + + // Attr contains the metadata returned from the FUSE server + Attr FUSEAttr +} diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 9c27f7bb2..91e35366f 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -412,7 +412,7 @@ func (ke *KernelIPTGetEntries) SizeBytes() int { func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) { ke.IPTGetEntries.MarshalBytes(dst) marshalledUntil := ke.IPTGetEntries.SizeBytes() - for i := 0; i < len(ke.Entrytable); i++ { + for i := range ke.Entrytable { ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:]) marshalledUntil += ke.Entrytable[i].SizeBytes() } @@ -422,7 +422,7 @@ func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) { func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) { ke.IPTGetEntries.UnmarshalBytes(src) unmarshalledUntil := ke.IPTGetEntries.SizeBytes() - for i := 0; i < len(ke.Entrytable); i++ { + for i := range ke.Entrytable { ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:]) unmarshalledUntil += ke.Entrytable[i].SizeBytes() } diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go new file mode 100644 index 000000000..9bb9efb10 --- /dev/null +++ b/pkg/abi/linux/netfilter_ipv6.go @@ -0,0 +1,310 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "io" + + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" +) + +// This file contains structures required to support IPv6 netfilter and +// ip6tables. Some constants and structs are equal to their IPv4 analogues, and +// are only distinguished by context (e.g. whether used on an IPv4 of IPv6 +// socket). + +// Socket options for SOL_SOCLET. These correspond to values in +// include/uapi/linux/netfilter_ipv6/ip6_tables.h. +const ( + IP6T_BASE_CTL = 64 + IP6T_SO_SET_REPLACE = IPT_BASE_CTL + IP6T_SO_SET_ADD_COUNTERS = IPT_BASE_CTL + 1 + IP6T_SO_SET_MAX = IPT_SO_SET_ADD_COUNTERS + + IP6T_SO_GET_INFO = IPT_BASE_CTL + IP6T_SO_GET_ENTRIES = IPT_BASE_CTL + 1 + IP6T_SO_GET_REVISION_MATCH = IPT_BASE_CTL + 4 + IP6T_SO_GET_REVISION_TARGET = IPT_BASE_CTL + 5 + IP6T_SO_GET_MAX = IP6T_SO_GET_REVISION_TARGET +) + +// IP6T_ORIGINAL_DST is the ip6tables SOL_IPV6 socket option. Corresponds to +// the value in include/uapi/linux/netfilter_ipv6/ip6_tables.h. +// TODO(gvisor.dev/issue/3549): Support IPv6 original destination. +const IP6T_ORIGINAL_DST = 80 + +// IP6TReplace is the argument for the IP6T_SO_SET_REPLACE sockopt. It +// corresponds to struct ip6t_replace in +// include/uapi/linux/netfilter_ipv6/ip6_tables.h. +// +// +marshal +type IP6TReplace struct { + Name TableName + ValidHooks uint32 + NumEntries uint32 + Size uint32 + HookEntry [NF_INET_NUMHOOKS]uint32 + Underflow [NF_INET_NUMHOOKS]uint32 + NumCounters uint32 + Counters uint64 // This is really a *XTCounters. + // Entries is omitted here because it would cause IP6TReplace to be an + // extra byte longer (see http://www.catb.org/esr/structure-packing/). + // Entries [0]IP6TEntry +} + +// SizeOfIP6TReplace is the size of an IP6TReplace. +const SizeOfIP6TReplace = 96 + +// KernelIP6TGetEntries is identical to IP6TGetEntries, but includes the +// Entrytable field. This has been manually made marshal.Marshallable since it +// is dynamically sized. +type KernelIP6TGetEntries struct { + IPTGetEntries + Entrytable []KernelIP6TEntry +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ke *KernelIP6TGetEntries) SizeBytes() int { + res := ke.IPTGetEntries.SizeBytes() + for _, entry := range ke.Entrytable { + res += entry.SizeBytes() + } + return res +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ke *KernelIP6TGetEntries) MarshalBytes(dst []byte) { + ke.IPTGetEntries.MarshalBytes(dst) + marshalledUntil := ke.IPTGetEntries.SizeBytes() + for i := range ke.Entrytable { + ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:]) + marshalledUntil += ke.Entrytable[i].SizeBytes() + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ke *KernelIP6TGetEntries) UnmarshalBytes(src []byte) { + ke.IPTGetEntries.UnmarshalBytes(src) + unmarshalledUntil := ke.IPTGetEntries.SizeBytes() + for i := range ke.Entrytable { + ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:]) + unmarshalledUntil += ke.Entrytable[i].SizeBytes() + } +} + +// Packed implements marshal.Marshallable.Packed. +func (ke *KernelIP6TGetEntries) Packed() bool { + // KernelIP6TGetEntries isn't packed because the ke.Entrytable contains + // an indirection to the actual data we want to marshal (the slice data + // pointer), and the memory for KernelIP6TGetEntries contains the slice + // header which we don't want to marshal. + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (ke *KernelIP6TGetEntries) MarshalUnsafe(dst []byte) { + // Fall back to safe Marshal because the type in not packed. + ke.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (ke *KernelIP6TGetEntries) UnmarshalUnsafe(src []byte) { + // Fall back to safe Unmarshal because the type in not packed. + ke.UnmarshalBytes(src) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (ke *KernelIP6TGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { + buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + length, err := task.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results + // in a partially unmarshalled struct. + ke.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (ke *KernelIP6TGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { + // Type KernelIP6TGetEntries doesn't have a packed layout in memory, + // fall back to MarshalBytes. + return task.CopyOutBytes(addr, ke.marshalAll(task)) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (ke *KernelIP6TGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { + // Type KernelIP6TGetEntries doesn't have a packed layout in memory, fall + // back to MarshalBytes. + return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit]) +} + +func (ke *KernelIP6TGetEntries) marshalAll(task marshal.Task) []byte { + buf := task.CopyScratchBuffer(ke.SizeBytes()) + ke.MarshalBytes(buf) + return buf +} + +// WriteTo implements io.WriterTo.WriteTo. +func (ke *KernelIP6TGetEntries) WriteTo(w io.Writer) (int64, error) { + buf := make([]byte, ke.SizeBytes()) + ke.MarshalBytes(buf) + length, err := w.Write(buf) + return int64(length), err +} + +var _ marshal.Marshallable = (*KernelIP6TGetEntries)(nil) + +// IP6TEntry is an iptables rule. It corresponds to struct ip6t_entry in +// include/uapi/linux/netfilter_ipv6/ip6_tables.h. +// +// +marshal +type IP6TEntry struct { + // IPv6 is used to filter packets based on the IPv6 header. + IPv6 IP6TIP + + // NFCache relates to kernel-internal caching and isn't used by + // userspace. + NFCache uint32 + + // TargetOffset is the byte offset from the beginning of this IPTEntry + // to the start of the entry's target. + TargetOffset uint16 + + // NextOffset is the byte offset from the beginning of this IPTEntry to + // the start of the next entry. It is thus also the size of the entry. + NextOffset uint16 + + // Comeback is a return pointer. It is not used by userspace. + Comeback uint32 + + _ [4]byte + + // Counters holds the packet and byte counts for this rule. + Counters XTCounters + + // Elems holds the data for all this rule's matches followed by the + // target. It is variable length -- users have to iterate over any + // matches and use TargetOffset and NextOffset to make sense of the + // data. + // + // Elems is omitted here because it would cause IPTEntry to be an extra + // byte larger (see http://www.catb.org/esr/structure-packing/). + // + // Elems [0]byte +} + +// SizeOfIP6TEntry is the size of an IP6TEntry. +const SizeOfIP6TEntry = 168 + +// KernelIP6TEntry is identical to IP6TEntry, but includes the Elems field. +// KernelIP6TEntry itself is not Marshallable but it implements some methods of +// marshal.Marshallable that help in other implementations of Marshallable. +type KernelIP6TEntry struct { + Entry IP6TEntry + + // Elems holds the data for all this rule's matches followed by the + // target. It is variable length -- users have to iterate over any + // matches and use TargetOffset and NextOffset to make sense of the + // data. + Elems primitive.ByteSlice +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ke *KernelIP6TEntry) SizeBytes() int { + return ke.Entry.SizeBytes() + ke.Elems.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ke *KernelIP6TEntry) MarshalBytes(dst []byte) { + ke.Entry.MarshalBytes(dst) + ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():]) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ke *KernelIP6TEntry) UnmarshalBytes(src []byte) { + ke.Entry.UnmarshalBytes(src) + ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():]) +} + +// IP6TIP contains information for matching a packet's IP header. +// It corresponds to struct ip6t_ip6 in +// include/uapi/linux/netfilter_ipv6/ip6_tables.h. +// +// +marshal +type IP6TIP struct { + // Src is the source IP address. + Src Inet6Addr + + // Dst is the destination IP address. + Dst Inet6Addr + + // SrcMask is the source IP mask. + SrcMask Inet6Addr + + // DstMask is the destination IP mask. + DstMask Inet6Addr + + // InputInterface is the input network interface. + InputInterface [IFNAMSIZ]byte + + // OutputInterface is the output network interface. + OutputInterface [IFNAMSIZ]byte + + // InputInterfaceMask is the input interface mask. + InputInterfaceMask [IFNAMSIZ]byte + + // OuputInterfaceMask is the output interface mask. + OutputInterfaceMask [IFNAMSIZ]byte + + // Protocol is the transport protocol. + Protocol uint16 + + // TOS matches TOS flags when Flags indicates filtering by TOS. + TOS uint8 + + // Flags define matching behavior for the IP header. + Flags uint8 + + // InverseFlags invert the meaning of fields in struct IPTIP. See the + // IP6T_INV_* flags. + InverseFlags uint8 + + // Linux defines in6_addr (Inet6Addr for us) as the union of a + // 16-element byte array and a 4-element 32-bit integer array, so the + // whole struct is 4-byte aligned. + _ [3]byte +} + +const SizeOfIP6TIP = 136 + +// Flags in IP6TIP.InverseFlags. Corresponding constants are in +// include/uapi/linux/netfilter_ipv6/ip6_tables.h. +const ( + // Invert the meaning of InputInterface. + IP6T_INV_VIA_IN = 0x01 + // Invert the meaning of OutputInterface. + IP6T_INV_VIA_OUT = 0x02 + // Invert the meaning of TOS. + IP6T_INV_TOS = 0x04 + // Invert the meaning of Src. + IP6T_INV_SRCIP = 0x08 + // Invert the meaning of Dst. + IP6T_INV_DSTIP = 0x10 + // Invert the meaning of the IPT_F_FRAG flag. + IP6T_INV_FRAG = 0x20 + // Enable all flags. + IP6T_INV_MASK = 0x7F +) diff --git a/pkg/abi/linux/netfilter_test.go b/pkg/abi/linux/netfilter_test.go index 565dd550e..bf73271c6 100644 --- a/pkg/abi/linux/netfilter_test.go +++ b/pkg/abi/linux/netfilter_test.go @@ -36,6 +36,9 @@ func TestSizes(t *testing.T) { {XTEntryTarget{}, SizeOfXTEntryTarget}, {XTErrorTarget{}, SizeOfXTErrorTarget}, {XTStandardTarget{}, SizeOfXTStandardTarget}, + {IP6TReplace{}, SizeOfIP6TReplace}, + {IP6TEntry{}, SizeOfIP6TEntry}, + {IP6TIP{}, SizeOfIP6TIP}, } for _, tc := range testCases { diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 693996c01..e37c8727d 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -259,6 +259,11 @@ type InetMulticastRequestWithNIC struct { InterfaceIndex int32 } +// Inet6Addr is struct in6_addr, from uapi/linux/in6.h. +// +// +marshal +type Inet6Addr [16]byte + // SockAddrInet6 is struct sockaddr_in6, from uapi/linux/in6.h. type SockAddrInet6 struct { Family uint16 diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go index c757583e0..b78fdab7a 100644 --- a/pkg/p9/client_test.go +++ b/pkg/p9/client_test.go @@ -62,6 +62,8 @@ func TestVersion(t *testing.T) { } func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) error) { + b.ReportAllocs() + // See above. serverSocket, clientSocket, err := unet.SocketPair(false) if err != nil { diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go index 7cec0e86d..02e665345 100644 --- a/pkg/p9/transport.go +++ b/pkg/p9/transport.go @@ -66,14 +66,17 @@ const ( var dataPool = sync.Pool{ New: func() interface{} { // These buffers are used for decoding without a payload. - return make([]byte, initialBufferLength) + // We need to return a pointer to avoid unnecessary allocations + // (see https://staticcheck.io/docs/checks#SA6002). + b := make([]byte, initialBufferLength) + return &b }, } // send sends the given message over the socket. func send(s *unet.Socket, tag Tag, m message) error { - data := dataPool.Get().([]byte) - dataBuf := buffer{data: data[:0]} + data := dataPool.Get().(*[]byte) + dataBuf := buffer{data: (*data)[:0]} if log.IsLogging(log.Debug) { log.Debugf("send [FD %d] [Tag %06d] %s", s.FD(), tag, m.String()) @@ -141,7 +144,7 @@ func send(s *unet.Socket, tag Tag, m message) error { } // All set. - dataPool.Put(dataBuf.data) + dataPool.Put(&dataBuf.data) return nil } @@ -227,12 +230,29 @@ func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message, // Not yet initialized. var dataBuf buffer + var vecs [][]byte + + appendBuffer := func(size int) *[]byte { + // Pull a data buffer from the pool. + datap := dataPool.Get().(*[]byte) + data := *datap + if size > len(data) { + // Create a larger data buffer. + data = make([]byte, size) + datap = &data + } else { + // Limit the data buffer. + data = data[:size] + } + dataBuf = buffer{data: data} + vecs = append(vecs, data) + return datap + } // Read the rest of the payload. // // This requires some special care to ensure that the vectors all line // up the way they should. We do this to minimize copying data around. - var vecs [][]byte if payloader, ok := m.(payloader); ok { fixedSize := payloader.FixedSize() @@ -246,22 +266,8 @@ func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message, } if fixedSize != 0 { - // Pull a data buffer from the pool. - data := dataPool.Get().([]byte) - if int(fixedSize) > len(data) { - // Create a larger data buffer, ensuring - // sufficient capicity for the message. - data = make([]byte, fixedSize) - defer dataPool.Put(data) - dataBuf = buffer{data: data} - vecs = append(vecs, data) - } else { - // Limit the data buffer, and make sure it - // gets filled before the payload buffer. - defer dataPool.Put(data) - dataBuf = buffer{data: data[:fixedSize]} - vecs = append(vecs, data[:fixedSize]) - } + datap := appendBuffer(int(fixedSize)) + defer dataPool.Put(datap) } // Include the payload. @@ -274,20 +280,8 @@ func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message, vecs = append(vecs, p) } } else if remaining != 0 { - // Pull a data buffer from the pool. - data := dataPool.Get().([]byte) - if int(remaining) > len(data) { - // Create a larger data buffer. - data = make([]byte, remaining) - defer dataPool.Put(data) - dataBuf = buffer{data: data} - vecs = append(vecs, data) - } else { - // Limit the data buffer. - defer dataPool.Put(data) - dataBuf = buffer{data: data[:remaining]} - vecs = append(vecs, data[:remaining]) - } + datap := appendBuffer(int(remaining)) + defer dataPool.Put(datap) } if len(vecs) > 0 { diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go index 3668fcad7..e7406b374 100644 --- a/pkg/p9/transport_test.go +++ b/pkg/p9/transport_test.go @@ -182,6 +182,8 @@ func TestSendClosed(t *testing.T) { } func BenchmarkSendRecv(b *testing.B) { + b.ReportAllocs() + server, client, err := unet.SocketPair(false) if err != nil { b.Fatalf("socketpair got err %v expected nil", err) diff --git a/pkg/refs_vfs2/refs_template.go b/pkg/refs_vfs2/refs_template.go index 3e5b458c7..99c43c065 100644 --- a/pkg/refs_vfs2/refs_template.go +++ b/pkg/refs_vfs2/refs_template.go @@ -59,7 +59,7 @@ func (r *Refs) finalize() { note = "(Leak checker uninitialized): " } if n := r.ReadRefs(); n != 0 { - log.Warningf("%sAtomicRefCount %p owned by %T garbage collected with ref count of %d (want 0)", note, r, ownerType, n) + log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, ownerType, n) } } diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD index c5fca2ba3..29aeaab8c 100644 --- a/pkg/seccomp/BUILD +++ b/pkg/seccomp/BUILD @@ -5,7 +5,11 @@ package(licenses = ["notice"]) go_binary( name = "victim", testonly = 1, - srcs = ["seccomp_test_victim.go"], + srcs = [ + "seccomp_test_victim.go", + "seccomp_test_victim_amd64.go", + "seccomp_test_victim_arm64.go", + ], deps = [":seccomp"], ) diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go index 88766f33b..5238df8bd 100644 --- a/pkg/seccomp/seccomp_test.go +++ b/pkg/seccomp/seccomp_test.go @@ -91,12 +91,12 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "Single syscall allowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64}, + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "Single syscall disallowed", - data: seccompData{nr: 2, arch: linux.AUDIT_ARCH_X86_64}, + data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -125,22 +125,22 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "Multiple rulesets allowed (1a)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x1}}, + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "Multiple rulesets allowed (1b)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64}, + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { desc: "Multiple rulesets allowed (2)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64}, + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { desc: "Multiple rulesets allowed (2)", - data: seccompData{nr: 0, arch: linux.AUDIT_ARCH_X86_64}, + data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_KILL_THREAD, }, }, @@ -160,42 +160,42 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "Multiple syscalls allowed (1)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64}, + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "Multiple syscalls allowed (3)", - data: seccompData{nr: 3, arch: linux.AUDIT_ARCH_X86_64}, + data: seccompData{nr: 3, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "Multiple syscalls allowed (5)", - data: seccompData{nr: 5, arch: linux.AUDIT_ARCH_X86_64}, + data: seccompData{nr: 5, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "Multiple syscalls disallowed (0)", - data: seccompData{nr: 0, arch: linux.AUDIT_ARCH_X86_64}, + data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { desc: "Multiple syscalls disallowed (2)", - data: seccompData{nr: 2, arch: linux.AUDIT_ARCH_X86_64}, + data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { desc: "Multiple syscalls disallowed (4)", - data: seccompData{nr: 4, arch: linux.AUDIT_ARCH_X86_64}, + data: seccompData{nr: 4, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { desc: "Multiple syscalls disallowed (6)", - data: seccompData{nr: 6, arch: linux.AUDIT_ARCH_X86_64}, + data: seccompData{nr: 6, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { desc: "Multiple syscalls disallowed (100)", - data: seccompData{nr: 100, arch: linux.AUDIT_ARCH_X86_64}, + data: seccompData{nr: 100, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -231,7 +231,7 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "Syscall disallowed, action trap", - data: seccompData{nr: 2, arch: linux.AUDIT_ARCH_X86_64}, + data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -254,12 +254,12 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "Syscall argument allowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xf}}, + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xf}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "Syscall argument disallowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xe}}, + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xe}}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -284,12 +284,12 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "Syscall argument allowed, two rules", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf}}, + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "Syscall argument allowed, two rules", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xe}}, + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xe}}, want: linux.SECCOMP_RET_ALLOW, }, }, @@ -315,7 +315,7 @@ func TestBasic(t *testing.T) { desc: "64bit syscall argument allowed", data: seccompData{ nr: 1, - arch: linux.AUDIT_ARCH_X86_64, + arch: LINUX_AUDIT_ARCH, args: [6]uint64{0, math.MaxUint64 - 1, math.MaxUint32}, }, want: linux.SECCOMP_RET_ALLOW, @@ -324,7 +324,7 @@ func TestBasic(t *testing.T) { desc: "64bit syscall argument disallowed", data: seccompData{ nr: 1, - arch: linux.AUDIT_ARCH_X86_64, + arch: LINUX_AUDIT_ARCH, args: [6]uint64{0, math.MaxUint64, math.MaxUint32}, }, want: linux.SECCOMP_RET_TRAP, @@ -333,7 +333,7 @@ func TestBasic(t *testing.T) { desc: "64bit syscall argument disallowed", data: seccompData{ nr: 1, - arch: linux.AUDIT_ARCH_X86_64, + arch: LINUX_AUDIT_ARCH, args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1}, }, want: linux.SECCOMP_RET_TRAP, @@ -358,32 +358,32 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "GreaterThan: Syscall argument allowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xffffffff}}, + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "GreaterThan: Syscall argument disallowed (equal)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xffffffff}}, + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "Syscall argument disallowed (smaller)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x0, 0xffffffff}}, + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "GreaterThan2: Syscall argument allowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xfbcd000d}}, + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xfbcd000d}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "GreaterThan2: Syscall argument disallowed (equal)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xabcd000d}}, + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "GreaterThan2: Syscall argument disallowed (smaller)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xa000ffff}}, + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -405,12 +405,12 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "IP: Syscall instruction pointer allowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{}, instructionPointer: 0x7aabbccdd}, + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x7aabbccdd}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "IP: Syscall instruction pointer disallowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{}, instructionPointer: 0x711223344}, + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x711223344}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -466,7 +466,7 @@ func TestRandom(t *testing.T) { t.Fatalf("bpf.Compile() got error: %v", err) } for i := uint32(0); i < 200; i++ { - data := seccompData{nr: i, arch: linux.AUDIT_ARCH_X86_64} + data := seccompData{nr: i, arch: LINUX_AUDIT_ARCH} got, err := bpf.Exec(p, data.asInput()) if err != nil { t.Errorf("bpf.Exec() got error: %v, for syscall %d", err, i) diff --git a/pkg/seccomp/seccomp_test_victim.go b/pkg/seccomp/seccomp_test_victim.go index da6b9eaaf..fe157f539 100644 --- a/pkg/seccomp/seccomp_test_victim.go +++ b/pkg/seccomp/seccomp_test_victim.go @@ -31,7 +31,6 @@ func main() { syscalls := seccomp.SyscallRules{ syscall.SYS_ACCEPT: {}, - syscall.SYS_ARCH_PRCTL: {}, syscall.SYS_BIND: {}, syscall.SYS_BRK: {}, syscall.SYS_CLOCK_GETTIME: {}, @@ -41,7 +40,6 @@ func main() { syscall.SYS_DUP3: {}, syscall.SYS_EPOLL_CREATE1: {}, syscall.SYS_EPOLL_CTL: {}, - syscall.SYS_EPOLL_WAIT: {}, syscall.SYS_EPOLL_PWAIT: {}, syscall.SYS_EXIT: {}, syscall.SYS_EXIT_GROUP: {}, @@ -68,8 +66,6 @@ func main() { syscall.SYS_MUNLOCK: {}, syscall.SYS_MUNMAP: {}, syscall.SYS_NANOSLEEP: {}, - syscall.SYS_NEWFSTATAT: {}, - syscall.SYS_OPEN: {}, syscall.SYS_PPOLL: {}, syscall.SYS_PREAD64: {}, syscall.SYS_PSELECT6: {}, @@ -97,6 +93,9 @@ func main() { syscall.SYS_WRITE: {}, syscall.SYS_WRITEV: {}, } + + arch_syscalls(syscalls) + die := *dieFlag if !die { syscalls[syscall.SYS_OPENAT] = []seccomp.Rule{ diff --git a/pkg/seccomp/seccomp_test_victim_amd64.go b/pkg/seccomp/seccomp_test_victim_amd64.go new file mode 100644 index 000000000..5dfc68e25 --- /dev/null +++ b/pkg/seccomp/seccomp_test_victim_amd64.go @@ -0,0 +1,32 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Test binary used to test that seccomp filters are properly constructed and +// indeed kill the process on violation. + +// +build amd64 + +package main + +import ( + "gvisor.dev/gvisor/pkg/seccomp" + "syscall" +) + +func arch_syscalls(syscalls seccomp.SyscallRules) { + syscalls[syscall.SYS_ARCH_PRCTL] = []seccomp.Rule{} + syscalls[syscall.SYS_EPOLL_WAIT] = []seccomp.Rule{} + syscalls[syscall.SYS_NEWFSTATAT] = []seccomp.Rule{} + syscalls[syscall.SYS_OPEN] = []seccomp.Rule{} +} diff --git a/pkg/seccomp/seccomp_test_victim_arm64.go b/pkg/seccomp/seccomp_test_victim_arm64.go new file mode 100644 index 000000000..5184d8ac4 --- /dev/null +++ b/pkg/seccomp/seccomp_test_victim_arm64.go @@ -0,0 +1,29 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Test binary used to test that seccomp filters are properly constructed and +// indeed kill the process on violation. + +// +build arm64 + +package main + +import ( + "gvisor.dev/gvisor/pkg/seccomp" + "syscall" +) + +func arch_syscalls(syscalls seccomp.SyscallRules) { + syscalls[syscall.SYS_FSTATAT] = []seccomp.Rule{} +} diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go index fd95eb2d2..0f433ee79 100644 --- a/pkg/sentry/arch/arch_aarch64.go +++ b/pkg/sentry/arch/arch_aarch64.go @@ -101,6 +101,8 @@ func NewFloatingPointData() *FloatingPointData { // State contains the common architecture bits for aarch64 (the build tag of this // file ensures it's only built on aarch64). +// +// +stateify savable type State struct { // The system registers. Regs Registers diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go index cabbf60e0..550741d8c 100644 --- a/pkg/sentry/arch/arch_arm64.go +++ b/pkg/sentry/arch/arch_arm64.go @@ -73,6 +73,8 @@ const ( ) // context64 represents an ARM64 context. +// +// +stateify savable type context64 struct { State sigFPState []aarch64FPState // fpstate to be restored on sigreturn. diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index a1405f7c3..83c24ec25 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -226,3 +226,99 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentr } return fd.VFSFileDescription(), nil } + +// statFromFUSEAttr makes attributes from linux.FUSEAttr to linux.Statx. The +// opts.Sync attribute is ignored since the synchronization is handled by the +// FUSE server. +func statFromFUSEAttr(attr linux.FUSEAttr, mask, devMinor uint32) linux.Statx { + var stat linux.Statx + stat.Blksize = attr.BlkSize + stat.DevMajor, stat.DevMinor = linux.UNNAMED_MAJOR, devMinor + + rdevMajor, rdevMinor := linux.DecodeDeviceID(attr.Rdev) + stat.RdevMajor, stat.RdevMinor = uint32(rdevMajor), rdevMinor + + if mask&linux.STATX_MODE != 0 { + stat.Mode = uint16(attr.Mode) + } + if mask&linux.STATX_NLINK != 0 { + stat.Nlink = attr.Nlink + } + if mask&linux.STATX_UID != 0 { + stat.UID = attr.UID + } + if mask&linux.STATX_GID != 0 { + stat.GID = attr.GID + } + if mask&linux.STATX_ATIME != 0 { + stat.Atime = linux.StatxTimestamp{ + Sec: int64(attr.Atime), + Nsec: attr.AtimeNsec, + } + } + if mask&linux.STATX_MTIME != 0 { + stat.Mtime = linux.StatxTimestamp{ + Sec: int64(attr.Mtime), + Nsec: attr.MtimeNsec, + } + } + if mask&linux.STATX_CTIME != 0 { + stat.Ctime = linux.StatxTimestamp{ + Sec: int64(attr.Ctime), + Nsec: attr.CtimeNsec, + } + } + if mask&linux.STATX_INO != 0 { + stat.Ino = attr.Ino + } + if mask&linux.STATX_SIZE != 0 { + stat.Size = attr.Size + } + if mask&linux.STATX_BLOCKS != 0 { + stat.Blocks = attr.Blocks + } + return stat +} + +// Stat implements kernfs.Inode.Stat. +func (i *inode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + fusefs := fs.Impl().(*filesystem) + conn := fusefs.conn + task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx) + if task == nil { + log.Warningf("couldn't get kernel task from context") + return linux.Statx{}, syserror.EINVAL + } + + var in linux.FUSEGetAttrIn + // We don't set any attribute in the request, because in VFS2 fstat(2) will + // finally be translated into vfs.FilesystemImpl.StatAt() (see + // pkg/sentry/syscalls/linux/vfs2/stat.go), resulting in the same flow + // as stat(2). Thus GetAttrFlags and Fh variable will never be used in VFS2. + req, err := conn.NewRequest(creds, uint32(task.ThreadID()), i.Ino(), linux.FUSE_GETATTR, &in) + if err != nil { + return linux.Statx{}, err + } + + res, err := conn.Call(task, req) + if err != nil { + return linux.Statx{}, err + } + if err := res.Error(); err != nil { + return linux.Statx{}, err + } + + var out linux.FUSEGetAttrOut + if err := res.UnmarshalPayload(&out); err != nil { + return linux.Statx{}, err + } + + // Set all metadata into kernfs.InodeAttrs. + if err := i.SetStat(ctx, fs, creds, vfs.SetStatOptions{ + Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, fusefs.devMinor), + }); err != nil { + return linux.Statx{}, err + } + + return statFromFUSEAttr(out.Attr, opts.Mask, fusefs.devMinor), nil +} diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index eaef2594d..610a7ed78 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -844,6 +844,13 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf } } if rp.Done() { + // Reject attempts to open mount root directory with O_CREAT. + if mayCreate && rp.MustBeDir() { + return nil, syserror.EISDIR + } + if mustCreate { + return nil, syserror.EEXIST + } return start.openLocked(ctx, rp, &opts) } @@ -856,6 +863,10 @@ afterTrailingSymlink: if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { return nil, err } + // Reject attempts to open directories with O_CREAT. + if mayCreate && rp.MustBeDir() { + return nil, syserror.EISDIR + } // Determine whether or not we need to create a file. parent.dirMu.Lock() child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) @@ -1071,10 +1082,8 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } creds := rp.Credentials() name := rp.Component() - // Filter file creation flags and O_LARGEFILE out; the create RPC already - // has the semantics of O_CREAT|O_EXCL, while some servers will choke on - // O_LARGEFILE. - createFlags := p9.OpenFlags(opts.Flags &^ (vfs.FileCreationFlags | linux.O_LARGEFILE)) + // We only want the access mode for creating the file. + createFlags := p9.OpenFlags(opts.Flags) & p9.OpenFlagsModeMask fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) if err != nil { dirfile.close(ctx) diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index bf922c566..bd6caba06 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -552,7 +552,7 @@ func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uin return syserror.ESPIPE } - // TODO(gvisor.dev/issue/2923): Implement Allocate for non-pipe hostfds. + // TODO(gvisor.dev/issue/3589): Implement Allocate for non-pipe hostfds. return syserror.EOPNOTSUPP } diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index e73732a6b..5cd428d64 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -26,6 +26,17 @@ go_template_instance( }, ) +go_template_instance( + name = "inode_refs", + out = "inode_refs.go", + package = "tmpfs", + prefix = "inode", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "inode", + }, +) + go_library( name = "tmpfs", srcs = [ @@ -34,6 +45,7 @@ go_library( "directory.go", "filesystem.go", "fstree.go", + "inode_refs.go", "named_pipe.go", "regular_file.go", "socket_file.go", @@ -47,6 +59,7 @@ go_library( "//pkg/context", "//pkg/fspath", "//pkg/log", + "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 065812065..a4864df53 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -320,7 +320,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf fs.mu.Lock() defer fs.mu.Unlock() if rp.Done() { - // Reject attempts to open directories with O_CREAT. + // Reject attempts to open mount root directory with O_CREAT. if rp.MustBeDir() { return nil, syserror.EISDIR } diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 4681a2f52..de2af6d01 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -285,13 +285,10 @@ type inode struct { // fs is the owning filesystem. fs is immutable. fs *filesystem - // refs is a reference count. refs is accessed using atomic memory - // operations. - // // A reference is held on all inodes as long as they are reachable in the // filesystem tree, i.e. nlink is nonzero. This reference is dropped when // nlink reaches 0. - refs int64 + refs inodeRefs // xattrs implements extended attributes. // @@ -327,7 +324,6 @@ func (i *inode) init(impl interface{}, fs *filesystem, kuid auth.KUID, kgid auth panic("file type is required in FileMode") } i.fs = fs - i.refs = 1 i.mode = uint32(mode) i.uid = uint32(kuid) i.gid = uint32(kgid) @@ -339,6 +335,7 @@ func (i *inode) init(impl interface{}, fs *filesystem, kuid auth.KUID, kgid auth i.mtime = now // i.nlink initialized by caller i.impl = impl + i.refs.EnableLeakCheck() } // incLinksLocked increments i's link count. @@ -369,25 +366,15 @@ func (i *inode) decLinksLocked(ctx context.Context) { } func (i *inode) incRef() { - if atomic.AddInt64(&i.refs, 1) <= 1 { - panic("tmpfs.inode.incRef() called without holding a reference") - } + i.refs.IncRef() } func (i *inode) tryIncRef() bool { - for { - refs := atomic.LoadInt64(&i.refs) - if refs == 0 { - return false - } - if atomic.CompareAndSwapInt64(&i.refs, refs, refs+1) { - return true - } - } + return i.refs.TryIncRef() } func (i *inode) decRef(ctx context.Context) { - if refs := atomic.AddInt64(&i.refs, -1); refs == 0 { + i.refs.DecRef(func() { i.watches.HandleDeletion(ctx) if regFile, ok := i.impl.(*regularFile); ok { // Release memory used by regFile to store data. Since regFile is @@ -395,9 +382,7 @@ func (i *inode) decRef(ctx context.Context) { // metadata. regFile.data.DropAll(regFile.memFile) } - } else if refs < 0 { - panic("tmpfs.inode.decRef() called without holding a reference") - } + }) } func (i *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error { diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go index fdc599477..9a7be3655 100644 --- a/pkg/sentry/platform/kvm/kvm_const_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go @@ -72,6 +72,7 @@ const ( _TCR_T0SZ_VA48 = 64 - 48 // VA=48 _TCR_T1SZ_VA48 = 64 - 48 // VA=48 + _TCR_A1 = 1 << 22 _TCR_ASID16 = 1 << 36 _TCR_TBI0 = 1 << 37 diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 307a7645f..905712076 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -61,7 +61,6 @@ func (c *vCPU) initArchState() error { reg.addr = uint64(reflect.ValueOf(&data).Pointer()) regGet.addr = uint64(reflect.ValueOf(&dataGet).Pointer()) - vcpuInit.target = _KVM_ARM_TARGET_GENERIC_V8 vcpuInit.features[0] |= (1 << _KVM_ARM_VCPU_PSCI_0_2) if _, _, errno := syscall.RawSyscall( syscall.SYS_IOCTL, @@ -272,8 +271,16 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return c.fault(int32(syscall.SIGSEGV), info) case ring0.Vector(bounce): // ring0.VirtualizationException return usermem.NoAccess, platform.ErrContextInterrupt + case ring0.El0Sync_undef, + ring0.El1Sync_undef: + *info = arch.SignalInfo{ + Signo: int32(syscall.SIGILL), + Code: 1, // ILL_ILLOPC (illegal opcode). + } + info.SetAddr(switchOpts.Registers.Pc) // Include address. + return usermem.AccessType{}, platform.ErrContextSignal default: - return usermem.NoAccess, platform.ErrContextSignal + panic(fmt.Sprintf("unexpected vector: 0x%x", vector)) } } diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index d8a7bc2f9..9d29b7168 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -364,6 +364,9 @@ TEXT ·Halt(SB),NOSPLIT,$0 CMP RSV_REG, R9 BNE mmio_exit MOVD $0, CPU_REGISTERS+PTRACE_R9(RSV_REG) + + // Flush dcache. + WORD $0xd5087e52 // DC CISW mmio_exit: // Disable fpsimd. WORD $0xd5381041 // MRS CPACR_EL1, R1 @@ -381,6 +384,9 @@ mmio_exit: MRS VBAR_EL1, R9 MOVD R0, 0x0(R9) + // Flush dcahce. + WORD $0xd5087e52 // DC CISW + RET // HaltAndResume halts execution and point the pointer to the resume function. @@ -414,6 +420,7 @@ TEXT ·Current(SB),NOSPLIT,$0-8 // Prepare the vcpu environment for container application. TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 // Step1, save sentry context into memory. + MRS TPIDR_EL1, RSV_REG REGISTERS_SAVE(RSV_REG, CPU_REGISTERS) MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG) @@ -425,34 +432,13 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 MOVD CPU_REGISTERS+PTRACE_R3(RSV_REG), R3 - // Step2, save SP_EL1, PSTATE into kernel temporary stack. - // switch to temporary stack. + // Step2, switch to temporary stack. LOAD_KERNEL_STACK(RSV_REG) - WORD $0xd538d092 //MRS TPIDR_EL1, R18 - - SUB $STACK_FRAME_SIZE, RSP, RSP - MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R11 - MOVD CPU_REGISTERS+PTRACE_PSTATE(RSV_REG), R12 - STP (R11, R12), 16*0(RSP) - - MOVD CPU_REGISTERS+PTRACE_R11(RSV_REG), R11 - MOVD CPU_REGISTERS+PTRACE_R12(RSV_REG), R12 - // Step3, test user pagetable. - // If user pagetable is empty, trapped in el1_ia. - WORD $0xd538d092 //MRS TPIDR_EL1, R18 - SWITCH_TO_APP_PAGETABLE(RSV_REG) - WORD $0xd538d092 //MRS TPIDR_EL1, R18 - SWITCH_TO_KVM_PAGETABLE(RSV_REG) - WORD $0xd538d092 //MRS TPIDR_EL1, R18 - - // If pagetable is not empty, recovery kernel temporary stack. - ADD $STACK_FRAME_SIZE, RSP, RSP - - // Step4, load app context pointer. + // Step3, load app context pointer. MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP - // Step5, prepare the environment for container application. + // Step4, prepare the environment for container application. // set sp_el0. MOVD PTRACE_SP(RSV_REG_APP), R1 WORD $0xd5184101 //MSR R1, SP_EL0 @@ -480,13 +466,13 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 LDP 16*0(RSP), (RSV_REG, RSV_REG_APP) ADD $STACK_FRAME_SIZE, RSP, RSP + ISB $15 ERET() // kernelExitToEl1 is the entrypoint for sentry in guest_el1. // Prepare the vcpu environment for sentry. TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 WORD $0xd538d092 //MRS TPIDR_EL1, R18 - MOVD CPU_REGISTERS+PTRACE_PSTATE(RSV_REG), R1 WORD $0xd5184001 //MSR R1, SPSR_EL1 @@ -503,6 +489,8 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 // Start is the CPU entrypoint. TEXT ·Start(SB),NOSPLIT,$0 + // Flush dcache. + WORD $0xd5087e52 // DC CISW // Init. MOVD $SCTLR_EL1_DEFAULT, R1 MSR R1, SCTLR_EL1 @@ -558,6 +546,7 @@ TEXT ·El1_sync(SB),NOSPLIT,$0 B el1_invalid el1_da: +el1_ia: WORD $0xd538d092 //MRS TPIDR_EL1, R18 WORD $0xd538601a //MRS FAR_EL1, R26 @@ -570,9 +559,6 @@ el1_da: B ·HaltAndResume(SB) -el1_ia: - B ·HaltAndResume(SB) - el1_sp_pc: B ·Shutdown(SB) @@ -644,9 +630,10 @@ el0_svc: MOVD $Syscall, R3 MOVD R3, CPU_VECTOR_CODE(RSV_REG) - B ·HaltAndResume(SB) + B ·kernelExitToEl1(SB) el0_da: +el0_ia: WORD $0xd538d092 //MRS TPIDR_EL1, R18 WORD $0xd538601a //MRS FAR_EL1, R26 @@ -658,10 +645,10 @@ el0_da: MOVD $PageFault, R3 MOVD R3, CPU_VECTOR_CODE(RSV_REG) - B ·HaltAndResume(SB) + MRS ESR_EL1, R3 + MOVD R3, CPU_ERROR_CODE(RSV_REG) -el0_ia: - B ·Shutdown(SB) + B ·kernelExitToEl1(SB) el0_fpsimd_acc: B ·Shutdown(SB) @@ -676,7 +663,10 @@ el0_sp_pc: B ·Shutdown(SB) el0_undef: - B ·Shutdown(SB) + MOVD $El0Sync_undef, R3 + MOVD R3, CPU_VECTOR_CODE(RSV_REG) + + B ·kernelExitToEl1(SB) el0_dbg: B ·Shutdown(SB) diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index 42009dac0..d0afa1aaa 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -53,7 +53,6 @@ func IsCanonical(addr uint64) bool { //go:nosplit func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { - // Sanitize registers. regs := switchOpts.Registers regs.Pstate &= ^uint64(PsrFlagsClear) @@ -69,6 +68,5 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { vector = c.vecCode - // Perform the switch. return } diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go index 78510ebed..6409d1d91 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go @@ -72,13 +72,14 @@ const ( ) const ( - mtNormal = 0x4 << 2 + mtDevicenGnRE = 0x1 << 2 + mtNormal = 0x4 << 2 ) const ( executeDisable = xn optionMask = 0xfff | 0xfff<<48 - protDefault = accessed | shared | mtNormal + protDefault = accessed | shared ) // MapOpts are x86 options. @@ -184,8 +185,10 @@ func (p *PTE) Set(addr uintptr, opts MapOpts) { if opts.User { v |= user + v |= mtNormal } else { v = v &^ user + v |= mtDevicenGnRE // Strong order for the addresses with ring0.KernelStartAddress. } atomic.StoreUintptr((*uintptr)(p), v) } @@ -200,7 +203,7 @@ func (p *PTE) setPageTable(pt *PageTables, ptes *PTEs) { // This should never happen. panic("unaligned physical address!") } - v := addr | typeTable | protDefault + v := addr | typeTable | protDefault | mtNormal atomic.StoreUintptr((*uintptr)(p), v) } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 4f98ee2d5..0bfd6c1f4 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -97,7 +97,7 @@ func (*TCPMatcher) Name() string { // Match implements Matcher.Match. func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) if netHeader.TransportProtocol() != header.TCPProtocolNumber { return false, false @@ -111,7 +111,7 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceN return false, false } - tcpHeader := header.TCP(pkt.TransportHeader) + tcpHeader := header.TCP(pkt.TransportHeader().View()) if len(tcpHeader) < header.TCPMinimumSize { // There's no valid TCP header here, so we drop the packet immediately. return false, true diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 3f20fc891..7ed05461d 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -94,7 +94,7 @@ func (*UDPMatcher) Name() string { // Match implements Matcher.Match. func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved // into the stack.Check codepath as matchers are added. @@ -110,7 +110,7 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceN return false, false } - udpHeader := header.UDP(pkt.TransportHeader) + udpHeader := header.UDP(pkt.TransportHeader().View()) if len(udpHeader) < header.UDPMinimumSize { // There's no valid UDP header here, so we drop the packet immediately. return false, true diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go index 399b4f60c..42559bf69 100644 --- a/pkg/sentry/syscalls/linux/vfs2/aio.go +++ b/pkg/sentry/syscalls/linux/vfs2/aio.go @@ -144,6 +144,12 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr usermem.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, aioCtx *mm.AIOContext) kernel.AIOCallback { return func(ctx context.Context) { + // Release references after completing the callback. + defer fd.DecRef(ctx) + if eventFD != nil { + defer eventFD.DecRef(ctx) + } + if aioCtx.Dead() { aioCtx.CancelPendingRequest() return @@ -169,8 +175,6 @@ func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr use ev.Result = -int64(kernel.ExtractErrno(err, 0)) } - fd.DecRef(ctx) - // Queue the result for delivery. aioCtx.FinishRequest(ev) @@ -179,7 +183,6 @@ func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr use // wake up. if eventFD != nil { eventFD.Impl().(*eventfd.EventFileDescription).Signal(1) - eventFD.DecRef(ctx) } } } diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index d3c1197e3..dcafffe57 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -289,7 +289,7 @@ func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Crede if flags&linux.O_DIRECT != 0 && !fd.opts.AllowDirectIO { return syserror.EINVAL } - // TODO(jamieliu): FileDescriptionImpl.SetOAsync()? + // TODO(gvisor.dev/issue/1035): FileDescriptionImpl.SetOAsync()? const settableFlags = linux.O_APPEND | linux.O_ASYNC | linux.O_DIRECT | linux.O_NOATIME | linux.O_NONBLOCK fd.flagsMu.Lock() if fd.asyncHandler != nil { @@ -301,7 +301,7 @@ func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Crede fd.asyncHandler.Unregister(fd) } } - fd.statusFlags = (oldFlags &^ settableFlags) | (flags & settableFlags) + atomic.StoreUint32(&fd.statusFlags, (oldFlags&^settableFlags)|(flags&settableFlags)) fd.flagsMu.Unlock() return nil } diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 9a3c5d6c3..ea0c5413d 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -65,6 +65,16 @@ func (v View) ToVectorisedView() VectorisedView { return NewVectorisedView(len(v), []View{v}) } +// IsEmpty returns whether v is of length zero. +func (v View) IsEmpty() bool { + return len(v) == 0 +} + +// Size returns the length of v. +func (v View) Size() int { + return len(v) +} + // VectorisedView is a vectorised version of View using non contiguous memory. // It supports all the convenience methods supported by View. // diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 1e5f5abf2..b769094dc 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -699,7 +699,7 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker { } // ICMPv4Code creates a checker that checks the ICMPv4 Code field. -func ICMPv4Code(want byte) TransportChecker { +func ICMPv4Code(want header.ICMPv4Code) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() @@ -757,7 +757,7 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker { } // ICMPv6Code creates a checker that checks the ICMPv6 Code field. -func ICMPv6Code(want byte) TransportChecker { +func ICMPv6Code(want header.ICMPv6Code) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 1a631b31a..be03fb086 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -54,6 +54,9 @@ const ( // ICMPv4Type is the ICMP type field described in RFC 792. type ICMPv4Type byte +// ICMPv4Code is the ICMP code field described in RFC 792. +type ICMPv4Code byte + // Typical values of ICMPv4Type defined in RFC 792. const ( ICMPv4EchoReply ICMPv4Type = 0 @@ -69,12 +72,13 @@ const ( ICMPv4InfoReply ICMPv4Type = 16 ) -// Values for ICMP code as defined in RFC 792. +// ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792. const ( - ICMPv4TTLExceeded = 0 - ICMPv4HostUnreachable = 1 - ICMPv4PortUnreachable = 3 - ICMPv4FragmentationNeeded = 4 + ICMPv4TTLExceeded ICMPv4Code = 0 + ICMPv4HostUnreachable ICMPv4Code = 1 + ICMPv4ProtoUnreachable ICMPv4Code = 2 + ICMPv4PortUnreachable ICMPv4Code = 3 + ICMPv4FragmentationNeeded ICMPv4Code = 4 ) // Type is the ICMP type field. @@ -84,10 +88,10 @@ func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) } func (b ICMPv4) SetType(t ICMPv4Type) { b[0] = byte(t) } // Code is the ICMP code field. Its meaning depends on the value of Type. -func (b ICMPv4) Code() byte { return b[1] } +func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) } // SetCode sets the ICMP code field. -func (b ICMPv4) SetCode(c byte) { b[1] = c } +func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) } // Checksum is the ICMP checksum field. func (b ICMPv4) Checksum() uint16 { diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index a13b4b809..20b01d8f4 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -92,7 +92,6 @@ const ( // ICMPv6Type is the ICMP type field described in RFC 4443 and friends. type ICMPv6Type byte -// Typical values of ICMPv6Type defined in RFC 4443. const ( ICMPv6DstUnreachable ICMPv6Type = 1 ICMPv6PacketTooBig ICMPv6Type = 2 @@ -110,18 +109,38 @@ const ( ICMPv6RedirectMsg ICMPv6Type = 137 ) -// Values for ICMP destination unreachable code as defined in RFC 4443 section -// 3.1. +// ICMPv6Code is the ICMP code field described in RFC 4443. +type ICMPv6Code byte + +// ICMP codes used with Destination Unreachable (Type 1). As per RFC 4443 +// section 3.1. +const ( + ICMPv6NetworkUnreachable ICMPv6Code = 0 + ICMPv6Prohibited ICMPv6Code = 1 + ICMPv6BeyondScope ICMPv6Code = 2 + ICMPv6AddressUnreachable ICMPv6Code = 3 + ICMPv6PortUnreachable ICMPv6Code = 4 + ICMPv6Policy ICMPv6Code = 5 + ICMPv6RejectRoute ICMPv6Code = 6 +) + +// ICMP codes used with Time Exceeded (Type 3). As per RFC 4443 section 3.3. const ( - ICMPv6NetworkUnreachable = 0 - ICMPv6Prohibited = 1 - ICMPv6BeyondScope = 2 - ICMPv6AddressUnreachable = 3 - ICMPv6PortUnreachable = 4 - ICMPv6Policy = 5 - ICMPv6RejectRoute = 6 + ICMPv6HopLimitExceeded ICMPv6Code = 0 + ICMPv6ReassemblyTimeout ICMPv6Code = 1 ) +// ICMP codes used with Parameter Problem (Type 4). As per RFC 4443 section 3.4. +const ( + ICMPv6ErroneousHeader ICMPv6Code = 0 + ICMPv6UnknownHeader ICMPv6Code = 1 + ICMPv6UnknownOption ICMPv6Code = 2 +) + +// ICMPv6UnusedCode is the code value used with ICMPv6 messages which don't use +// the code field. (Types not mentioned above.) +const ICMPv6UnusedCode ICMPv6Code = 0 + // Type is the ICMP type field. func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) } @@ -129,10 +148,10 @@ func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) } func (b ICMPv6) SetType(t ICMPv6Type) { b[0] = byte(t) } // Code is the ICMP code field. Its meaning depends on the value of Type. -func (b ICMPv6) Code() byte { return b[1] } +func (b ICMPv6) Code() ICMPv6Code { return ICMPv6Code(b[1]) } // SetCode sets the ICMP code field. -func (b ICMPv6) SetCode(c byte) { b[1] = c } +func (b ICMPv6) SetCode(c ICMPv6Code) { b[1] = byte(c) } // Checksum is the ICMP checksum field. func (b ICMPv6) Checksum() uint16 { diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index d0d1efd0d..680eafd16 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -315,3 +315,12 @@ func IsV4MulticastAddress(addr tcpip.Address) bool { } return (addr[0] & 0xf0) == 0xe0 } + +// IsV4LoopbackAddress determines if the provided address is an IPv4 loopback +// address (belongs to 127.0.0.1/8 subnet). +func IsV4LoopbackAddress(addr tcpip.Address) bool { + if len(addr) != IPv4AddressSize { + return false + } + return addr[0] == 0x7f +} diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 4f367fe4c..ea3823898 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -98,6 +98,9 @@ const ( // section 5. IPv6MinimumMTU = 1280 + // IPv6Loopback is the IPv6 Loopback address. + IPv6Loopback tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + // IPv6Any is the non-routable IPv6 "any" meta address. It is also // known as the unspecified address. IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index e12a5929b..c95aef63c 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -274,7 +274,9 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := PacketInfo{ - Pkt: &stack.PacketBuffer{Data: vv}, + Pkt: stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }), Proto: 0, GSO: nil, } diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 507b44abc..10072eac1 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -37,5 +37,6 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/rawfile", "//pkg/tcpip/stack", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index c18bb91fb..975309fc8 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -390,8 +390,7 @@ const ( func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if e.hdrSize > 0 { // Add ethernet header if needed. - eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) - pkt.LinkHeader = buffer.View(eth) + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) ethHdr := &header.EthernetFields{ DstAddr: remote, Type: protocol, @@ -420,7 +419,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { vnetHdr := virtioNetHdr{} if gso != nil { - vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) + vnetHdr.hdrLen = uint16(pkt.HeaderSize()) if gso.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen @@ -443,11 +442,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne builder.Add(vnetHdrBuf) } - builder.Add(pkt.Header.View()) - for _, v := range pkt.Data.Views() { + for _, v := range pkt.Views() { builder.Add(v) } - return rawfile.NonBlockingWriteIovec(fd, builder.Build()) } @@ -463,7 +460,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { vnetHdr := virtioNetHdr{} if pkt.GSOOptions != nil { - vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) + vnetHdr.hdrLen = uint16(pkt.HeaderSize()) if pkt.GSOOptions.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen @@ -486,8 +483,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc var builder iovec.Builder builder.Add(vnetHdrBuf) - builder.Add(pkt.Header.View()) - for _, v := range pkt.Data.Views() { + for _, v := range pkt.Views() { builder.Add(v) } iovecs := builder.Build() diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 7b995b85a..709f829c8 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -26,6 +26,7 @@ import ( "time" "unsafe" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -43,9 +44,36 @@ const ( ) type packetInfo struct { - raddr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - contents *stack.PacketBuffer + Raddr tcpip.LinkAddress + Proto tcpip.NetworkProtocolNumber + Contents *stack.PacketBuffer +} + +type packetContents struct { + LinkHeader buffer.View + NetworkHeader buffer.View + TransportHeader buffer.View + Data buffer.View +} + +func checkPacketInfoEqual(t *testing.T, got, want packetInfo) { + t.Helper() + if diff := cmp.Diff( + want, got, + cmp.Transformer("ExtractPacketBuffer", func(pk *stack.PacketBuffer) *packetContents { + if pk == nil { + return nil + } + return &packetContents{ + LinkHeader: pk.LinkHeader().View(), + NetworkHeader: pk.NetworkHeader().View(), + TransportHeader: pk.TransportHeader().View(), + Data: pk.Data.ToView(), + } + }), + ); diff != "" { + t.Errorf("unexpected packetInfo (-want +got):\n%s", diff) + } } type context struct { @@ -159,19 +187,28 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u RemoteLinkAddress: raddr, } - // Build header. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100) - b := hdr.Prepend(100) - for i := range b { - b[i] = uint8(rand.Intn(256)) + // Build payload. + payload := buffer.NewView(plen) + if _, err := rand.Read(payload); err != nil { + t.Fatalf("rand.Read(payload): %s", err) } - // Build payload and write. - payload := make(buffer.View, plen) - for i := range payload { - payload[i] = uint8(rand.Intn(256)) + // Build packet buffer. + const netHdrLen = 100 + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()) + netHdrLen, + Data: payload.ToVectorisedView(), + }) + pkt.Hash = hash + + // Build header. + b := pkt.NetworkHeader().Push(netHdrLen) + if _, err := rand.Read(b); err != nil { + t.Fatalf("rand.Read(b): %s", err) } - want := append(hdr.View(), payload...) + + // Write. + want := append(append(buffer.View(nil), b...), payload...) var gso *stack.GSO if gsoMaxSize != 0 { gso = &stack.GSO{ @@ -183,11 +220,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, proto, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - Hash: hash, - }); err != nil { + if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -296,13 +329,14 @@ func TestPreserveSrcAddress(t *testing.T) { LocalLinkAddress: baddr, } - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, proto, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.VectorisedView{}, - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + // WritePacket panics given a prependable with anything less than + // the minimum size of the ethernet header. + // TODO(b/153685824): Figure out if this should use c.ep.MaxHeaderLength(). + ReserveHeaderBytes: header.EthernetMinimumSize, + Data: buffer.VectorisedView{}, + }) + if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -331,24 +365,25 @@ func TestDeliverPacket(t *testing.T) { defer c.cleanup() // Build packet. - b := make([]byte, plen) - all := b - for i := range b { - b[i] = uint8(rand.Intn(256)) + all := make([]byte, plen) + if _, err := rand.Read(all); err != nil { + t.Fatalf("rand.Read(all): %s", err) } - - var hdr header.Ethernet - if !eth { - // So that it looks like an IPv4 packet. - b[0] = 0x40 - } else { - hdr = make(header.Ethernet, header.EthernetMinimumSize) + // Make it look like an IPv4 packet. + all[0] = 0x40 + + wantPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.EthernetMinimumSize, + Data: buffer.NewViewFromBytes(all).ToVectorisedView(), + }) + if eth { + hdr := header.Ethernet(wantPkt.LinkHeader().Push(header.EthernetMinimumSize)) hdr.Encode(&header.EthernetFields{ SrcAddr: raddr, DstAddr: laddr, Type: proto, }) - all = append(hdr, b...) + all = append(hdr, all...) } // Write packet via the file descriptor. @@ -360,24 +395,15 @@ func TestDeliverPacket(t *testing.T) { select { case pi := <-c.ch: want := packetInfo{ - raddr: raddr, - proto: proto, - contents: &stack.PacketBuffer{ - Data: buffer.View(b).ToVectorisedView(), - LinkHeader: buffer.View(hdr), - }, + Raddr: raddr, + Proto: proto, + Contents: wantPkt, } if !eth { - want.proto = header.IPv4ProtocolNumber - want.raddr = "" - } - // want.contents.Data will be a single - // view, so make pi do the same for the - // DeepEqual check. - pi.contents.Data = pi.contents.Data.ToView().ToVectorisedView() - if !reflect.DeepEqual(want, pi) { - t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) + want.Proto = header.IPv4ProtocolNumber + want.Raddr = "" } + checkPacketInfoEqual(t, pi, want) case <-time.After(10 * time.Second): t.Fatalf("Timed out waiting for packet") } @@ -572,8 +598,8 @@ func TestDispatchPacketFormat(t *testing.T) { t.Fatalf("len(sink.pkts) = %d, want %d", got, want) } pkt := sink.pkts[0] - if got, want := len(pkt.LinkHeader), header.EthernetMinimumSize; got != want { - t.Errorf("len(pkt.LinkHeader) = %d, want %d", got, want) + if got, want := pkt.LinkHeader().View().Size(), header.EthernetMinimumSize; got != want { + t.Errorf("pkt.LinkHeader().View().Size() = %d, want %d", got, want) } if got, want := pkt.Data.Size(), 4; got != want { t.Errorf("pkt.Data.Size() = %d, want %d", got, want) diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index 2dfd29aa9..c475dda20 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -18,6 +18,7 @@ package fdbased import ( "encoding/binary" + "fmt" "syscall" "golang.org/x/sys/unix" @@ -170,10 +171,9 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress - eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(pkt) + eth := header.Ethernet(pkt) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -190,10 +190,14 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } } - pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(remote, local, p, &stack.PacketBuffer{ - Data: buffer.View(pkt).ToVectorisedView(), - LinkHeader: buffer.View(eth), + pbuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.View(pkt).ToVectorisedView(), }) + if d.e.hdrSize > 0 { + if _, ok := pbuf.LinkHeader().Consume(d.e.hdrSize); !ok { + panic(fmt.Sprintf("LinkHeader().Consume(%d) must succeed", d.e.hdrSize)) + } + } + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pbuf) return true, nil } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index d8f2504b3..8c3ca86d6 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -103,7 +103,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { d.allocateViews(BufConfig) n, err := rawfile.BlockingReadv(d.fd, d.iovecs) - if err != nil { + if n == 0 || err != nil { return false, err } if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { @@ -111,17 +111,22 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { // isn't used and it isn't in a view. n -= virtioNetHdrSize } - if n <= d.e.hdrSize { - return false, nil - } + + used := d.capViews(n, BufConfig) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), + }) var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress - eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(d.views[0][:header.EthernetMinimumSize]) + hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize) + if !ok { + return false, nil + } + eth := header.Ethernet(hdr) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -138,13 +143,6 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } } - used := d.capViews(n, BufConfig) - pkt := &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), - LinkHeader: buffer.View(eth), - } - pkt.Data.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) // Prepare e.views for another packet: release used views. @@ -268,17 +266,22 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { n -= virtioNetHdrSize } - if n <= d.e.hdrSize { - return false, nil - } + + used := d.capViews(k, int(n), BufConfig) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), + }) var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress - eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(d.views[k][0][:header.EthernetMinimumSize]) + hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize) + if !ok { + return false, nil + } + eth := header.Ethernet(hdr) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -295,12 +298,6 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } } - used := d.capViews(k, int(n), BufConfig) - pkt := &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), - LinkHeader: buffer.View(eth), - } - pkt.Data.TrimFront(d.e.hdrSize) d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) // Prepare e.views for another packet: release used views. diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 781cdd317..38aa694e4 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -77,16 +77,16 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) + // Construct data as the unparsed portion for the loopback packet. + data := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) // Because we're immediately turning around and writing the packet back // to the rx path, we intentionally don't preserve the remote and local // link addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: data, }) + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, newPkt) return nil } @@ -98,18 +98,17 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) // There should be an ethernet header at the beginning of vv. - hdr, ok := vv.PullUp(header.EthernetMinimumSize) + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) if !ok { // Reject the packet if it's shorter than an ethernet header. return tcpip.ErrBadAddress } linkHeader := header.Ethernet(hdr) - vv.TrimFront(len(linkHeader)) - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), &stack.PacketBuffer{ - Data: vv, - LinkHeader: buffer.View(linkHeader), - }) + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), pkt) return nil } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 0744f66d6..3e4afcdad 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -46,14 +46,14 @@ func TestInjectableEndpointRawDispatch(t *testing.T) { func TestInjectableEndpointDispatch(t *testing.T) { endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - hdr := buffer.NewPrependable(1) - hdr.Prepend(1)[0] = 0xFA + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: 1, + Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), + }) + pkt.TransportHeader().Push(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), - }) + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) @@ -67,13 +67,14 @@ func TestInjectableEndpointDispatch(t *testing.T) { func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - hdr := buffer.NewPrependable(1) - hdr.Prepend(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.NewView(0).ToVectorisedView(), + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: 1, + Data: buffer.NewView(0).ToVectorisedView(), }) + pkt.TransportHeader().Push(1)[0] = 0xFA + packetRoute := stack.Route{RemoteAddress: dstIP} + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) if err != nil { diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go index 7d9249c1c..c1f9d308c 100644 --- a/pkg/tcpip/link/nested/nested_test.go +++ b/pkg/tcpip/link/nested/nested_test.go @@ -87,7 +87,7 @@ func TestNestedLinkEndpoint(t *testing.T) { t.Error("After attach, nestedEP.IsAttached() = false, want = true") } - nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{}) + nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if disp.count != 1 { t.Errorf("After first packet with dispatcher attached, got disp.count = %d, want = 1", disp.count) } @@ -101,7 +101,7 @@ func TestNestedLinkEndpoint(t *testing.T) { } disp.count = 0 - nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{}) + nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if disp.count != 0 { t.Errorf("After second packet with dispatcher detached, got disp.count = %d, want = 0", disp.count) } diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index 467083239..fc1e34fc7 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -199,7 +199,7 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // TODO(gvisor.dev/issue/3267/): Queue these packets as well once + // TODO(gvisor.dev/issue/3267): Queue these packets as well once // WriteRawPacket takes PacketBuffer instead of VectorisedView. return e.lower.WriteRawPacket(vv) } diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 507c76b76..7fb8a6c49 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -186,8 +186,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { // AddHeader implements stack.LinkEndpoint.AddHeader. func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { // Add ethernet header if needed. - eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) - pkt.LinkHeader = buffer.View(eth) + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) ethHdr := &header.EthernetFields{ DstAddr: remote, Type: protocol, @@ -207,10 +206,10 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) - v := pkt.Data.ToView() + views := pkt.Views() // Transmit the packet. e.mu.Lock() - ok := e.tx.transmit(pkt.Header.View(), v) + ok := e.tx.transmit(views...) e.mu.Unlock() if !ok { @@ -227,10 +226,10 @@ func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketB // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - v := vv.ToView() + views := vv.Views() // Transmit the packet. e.mu.Lock() - ok := e.tx.transmit(v, buffer.View{}) + ok := e.tx.transmit(views...) e.mu.Unlock() if !ok { @@ -276,16 +275,18 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { rxb[i].Size = e.bufferSize } - if n < header.EthernetMinimumSize { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.View(b).ToVectorisedView(), + }) + + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) + if !ok { continue } + eth := header.Ethernet(hdr) // Send packet up the stack. - eth := header.Ethernet(b[:header.EthernetMinimumSize]) - d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), &stack.PacketBuffer{ - Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), - LinkHeader: buffer.View(eth), - }) + d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), pkt) } // Clean state. diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 8f3cd9449..22d5c97f1 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -266,21 +266,23 @@ func TestSimpleSend(t *testing.T) { for iters := 1000; iters > 0; iters-- { func() { + hdrLen, dataLen := rand.Intn(10000), rand.Intn(10000) + // Prepare and send packet. - n := rand.Intn(10000) - hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength())) - hdrBuf := hdr.Prepend(n) + hdrBuf := buffer.NewView(hdrLen) randomFill(hdrBuf) - n = rand.Intn(10000) - buf := buffer.NewView(n) - randomFill(buf) + data := buffer.NewView(dataLen) + randomFill(data) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: hdrLen + int(c.ep.MaxHeaderLength()), + Data: data.ToVectorisedView(), + }) + copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -317,7 +319,7 @@ func TestSimpleSend(t *testing.T) { // Compare contents skipping the ethernet header added by the // endpoint. - merged := append(hdrBuf, buf...) + merged := append(hdrBuf, data...) if uint32(len(contents)) < pi.Size { t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size) } @@ -344,14 +346,14 @@ func TestPreserveSrcAddressInSend(t *testing.T) { LocalLinkAddress: newLocalLinkAddress, } - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - hdr := buffer.NewPrependable(header.EthernetMinimumSize) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + // WritePacket panics given a prependable with anything less than + // the minimum size of the ethernet header. + ReserveHeaderBytes: header.EthernetMinimumSize, + }) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ - Header: hdr, - }); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -403,12 +405,12 @@ func TestFillTxQueue(t *testing.T) { // until the tx queue if full. ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -422,11 +424,11 @@ func TestFillTxQueue(t *testing.T) { } // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != want { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -450,11 +452,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } @@ -473,11 +475,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // until the tx queue if full. ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -491,11 +493,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { } // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != want { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -517,11 +519,11 @@ func TestFillTxMemory(t *testing.T) { // we fill the memory. ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -536,11 +538,11 @@ func TestFillTxMemory(t *testing.T) { } // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), }) + err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt) if want := tcpip.ErrWouldBlock; err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } @@ -564,11 +566,11 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Each packet is uses up one buffer, so write as many as possible // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -579,23 +581,22 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write a two-buffer packet. It must fail. { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: uu, - }); err != want { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buffer.NewView(bufferSize).ToVectorisedView(), + }) + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } // Attempt to write the one-buffer packet again. It must succeed. { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go index 6b8d7859d..44f421c2d 100644 --- a/pkg/tcpip/link/sharedmem/tx.go +++ b/pkg/tcpip/link/sharedmem/tx.go @@ -18,6 +18,7 @@ import ( "math" "syscall" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" ) @@ -76,9 +77,9 @@ func (t *tx) cleanup() { syscall.Munmap(t.data) } -// transmit sends a packet made up of up to two buffers. Returns a boolean that -// specifies whether the packet was successfully transmitted. -func (t *tx) transmit(a, b []byte) bool { +// transmit sends a packet made of bufs. Returns a boolean that specifies +// whether the packet was successfully transmitted. +func (t *tx) transmit(bufs ...buffer.View) bool { // Pull completions from the tx queue and add their buffers back to the // pool so that we can reuse them. for { @@ -93,7 +94,10 @@ func (t *tx) transmit(a, b []byte) bool { } bSize := t.bufs.entrySize - total := uint32(len(a) + len(b)) + total := uint32(0) + for _, data := range bufs { + total += uint32(len(data)) + } bufCount := (total + bSize - 1) / bSize // Allocate enough buffers to hold all the data. @@ -115,7 +119,7 @@ func (t *tx) transmit(a, b []byte) bool { // Copy data into allocated buffers. nBuf := buf var dBuf []byte - for _, data := range [][]byte{a, b} { + for _, data := range bufs { for len(data) > 0 { if len(dBuf) == 0 { dBuf = t.data[nBuf.Offset:][:nBuf.Size] diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 509076643..4fb127978 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -134,7 +134,7 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw logPacket(prefix, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { - totalLength := pkt.Header.UsedLength() + pkt.Data.Size() + totalLength := pkt.Size() length := totalLength if max := int(e.maxPCAPLen); length > max { length = max @@ -155,12 +155,11 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw length -= n } } - write(pkt.Header.View()) - for _, view := range pkt.Data.Views() { + for _, v := range pkt.Views() { if length == 0 { break } - write(view) + write(v) } } } @@ -185,9 +184,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - e.dumpPacket("send", nil, 0, &stack.PacketBuffer{ + e.dumpPacket("send", nil, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - }) + })) return e.Endpoint.WriteRawPacket(vv) } @@ -201,12 +200,8 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P var fragmentOffset uint16 var moreFragments bool - // Create a clone of pkt, including any headers if present. Avoid allocating - // backing memory for the clone. - views := [8]buffer.View{} - vv := buffer.NewVectorisedView(0, views[:0]) - vv.AppendView(pkt.Header.View()) - vv.Append(pkt.Data) + // Examine the packet using a new VV. Backing storage must not be written. + vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) switch protocol { case header.IPv4ProtocolNumber: diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 22b0a12bd..3b1510a33 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -215,12 +215,11 @@ func (d *Device) Write(data []byte) (int64, error) { remote = tcpip.LinkAddress(zeroMAC[:]) } - pkt := &stack.PacketBuffer{ - Data: buffer.View(data).ToVectorisedView(), - } - if ethHdr != nil { - pkt.LinkHeader = buffer.View(ethHdr) - } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: len(ethHdr), + Data: buffer.View(data).ToVectorisedView(), + }) + copy(pkt.LinkHeader().Push(len(ethHdr)), ethHdr) endpoint.InjectLinkAddr(protocol, remote, pkt) return dataLen, nil } @@ -265,21 +264,22 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { // If the packet does not already have link layer header, and the route // does not exist, we can't compute it. This is possibly a raw packet, tun // device doesn't support this at the moment. - if info.Pkt.LinkHeader == nil && info.Route.RemoteLinkAddress == "" { + if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress == "" { return nil, false } // Ethernet header (TAP only). if d.hasFlags(linux.IFF_TAP) { // Add ethernet header if not provided. - if info.Pkt.LinkHeader == nil { + if info.Pkt.LinkHeader().View().IsEmpty() { d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt) } - vv.AppendView(info.Pkt.LinkHeader) + vv.AppendView(info.Pkt.LinkHeader().View()) } // Append upper headers. - vv.AppendView(buffer.View(info.Pkt.Header.View()[len(info.Pkt.LinkHeader):])) + vv.AppendView(info.Pkt.NetworkHeader().View()) + vv.AppendView(info.Pkt.TransportHeader().View()) // Append data payload. vv.Append(info.Pkt.Data) @@ -361,8 +361,7 @@ func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip. if !e.isTap { return } - eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) - pkt.LinkHeader = buffer.View(eth) + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) hdr := &header.EthernetFields{ SrcAddr: local, DstAddr: remote, diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index c448a888f..94827fc56 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -104,21 +104,21 @@ func TestWaitWrite(t *testing.T) { wep := New(ep) // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } @@ -135,21 +135,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 31a242482..1ad788a17 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -99,7 +99,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu } func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - h := header.ARP(pkt.NetworkHeader) + h := header.ARP(pkt.NetworkHeader().View()) if !h.IsValid() { return } @@ -110,17 +110,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } - hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) - packet := header.ARP(hdr.Prepend(header.ARPSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize, + }) + packet := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) packet.SetIPv4OverEthernet() packet.SetOp(header.ARPReply) copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - }) + _ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) fallthrough // also fill the cache from requests case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) @@ -168,17 +168,17 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd r.RemoteLinkAddress = header.EthernetBroadcastAddress } - hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize) - h := header.ARP(hdr.Prepend(header.ARPSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.ARPSize, + }) + h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) h.SetIPv4OverEthernet() h.SetOp(header.ARPRequest) copy(h.HardwareAddressSender(), linkEP.LinkAddress()) copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - }) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. @@ -210,12 +210,10 @@ func (*protocol) Wait() {} // Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { - hdr, ok := pkt.Data.PullUp(header.ARPSize) + _, ok = pkt.NetworkHeader().Consume(header.ARPSize) if !ok { return 0, false, false } - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(header.ARPSize) return 0, false, true } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index a35a64a0f..c2c3e6891 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -106,9 +106,9 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: v.ToVectorisedView(), - }) + })) } for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { @@ -118,9 +118,9 @@ func TestDirectRequest(t *testing.T) { if pi.Proto != arp.ProtocolNumber { t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) } - rep := header.ARP(pi.Pkt.Header.View()) + rep := header.ARP(pi.Pkt.NetworkHeader().View()) if !rep.IsValid() { - t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength()) + t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) } if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr1; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 615bae648..491d936a1 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -156,13 +156,13 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne var dstAddr tcpip.Address if t.v4 { - h := header.IPv4(pkt.Header.View()) + h := header.IPv4(pkt.NetworkHeader().View()) prot = tcpip.TransportProtocolNumber(h.Protocol()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } else { - h := header.IPv6(pkt.Header.View()) + h := header.IPv6(pkt.NetworkHeader().View()) prot = tcpip.TransportProtocolNumber(h.NextHeader()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() @@ -243,8 +243,11 @@ func TestIPv4Send(t *testing.T) { payload[i] = uint8(i) } - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + // Setup the packet buffer. + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(ep.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + }) // Issue the write. o.protocol = 123 @@ -260,10 +263,7 @@ func TestIPv4Send(t *testing.T) { Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }); err != nil { + }, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -303,9 +303,13 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - pkt := stack.PacketBuffer{Data: view.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -317,7 +321,7 @@ func TestIPv4ReceiveControl(t *testing.T) { name string expectedCount int fragmentOffset uint16 - code uint8 + code header.ICMPv4Code expectedTyp stack.ControlType expectedExtra uint32 trunc int @@ -455,17 +459,25 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - pkt := stack.PacketBuffer{Data: frag1.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: frag1.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) } // Send second segment. - pkt = stack.PacketBuffer{Data: frag2.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: frag2.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -485,8 +497,11 @@ func TestIPv6Send(t *testing.T) { payload[i] = uint8(i) } - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + // Setup the packet buffer. + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(ep.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + }) // Issue the write. o.protocol = 123 @@ -502,10 +517,7 @@ func TestIPv6Send(t *testing.T) { Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }); err != nil { + }, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -545,9 +557,13 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - pkt := stack.PacketBuffer{Data: view.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -563,7 +579,7 @@ func TestIPv6ReceiveControl(t *testing.T) { expectedCount int fragmentOffset *uint16 typ header.ICMPv6Type - code uint8 + code header.ICMPv6Code expectedTyp stack.ControlType expectedExtra uint32 trunc int @@ -673,11 +689,9 @@ func TestIPv6ReceiveControl(t *testing.T) { // becomes Data. func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer { v := view[:len(view)-trunc] - if len(v) < netHdrLen { - return &stack.PacketBuffer{Data: v.ToVectorisedView()} - } - return &stack.PacketBuffer{ - NetworkHeader: v[:netHdrLen], - Data: v[netHdrLen:].ToVectorisedView(), - } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v.ToVectorisedView(), + }) + _, _ = pkt.NetworkHeader().Consume(netHdrLen) + return pkt } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 94803a359..067d770f3 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -89,12 +89,14 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { return } + // Make a copy of data before pkt gets sent to raw socket. + // DeliverTransportPacket will take ownership of pkt. + replyData := pkt.Data.Clone(nil) + replyData.TrimFront(header.ICMPv4MinimumSize) + // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, &stack.PacketBuffer{ - Data: pkt.Data.Clone(nil), - NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), - }) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) remoteLinkAddr := r.RemoteLinkAddress @@ -116,24 +118,26 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // Use the remote link address from the incoming packet. r.ResolveWith(remoteLinkAddr) - vv := pkt.Data.Clone(nil) - vv.TrimFront(header.ICMPv4MinimumSize) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - copy(pkt, h) - pkt.SetType(header.ICMPv4EchoReply) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) + // Prepare a reply packet. + icmpHdr := make(header.ICMPv4, header.ICMPv4MinimumSize) + copy(icmpHdr, h) + icmpHdr.SetType(header.ICMPv4EchoReply) + icmpHdr.SetChecksum(0) + icmpHdr.SetChecksum(^header.Checksum(icmpHdr, header.ChecksumVV(replyData, 0))) + dataVV := buffer.View(icmpHdr).ToVectorisedView() + dataVV.Append(replyData) + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: dataVV, + }) + + // Send out the reply packet. sent := stats.ICMP.V4PacketsSent if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: vv, - TransportHeader: buffer.View(pkt), - }); err != nil { + }, replyPkt); err != nil { sent.Dropped.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 6c4f0ae3e..3cd48ceb3 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -21,7 +21,6 @@ package ipv4 import ( - "fmt" "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" @@ -127,14 +126,12 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { } // writePacketFragments calls e.linkEP.WritePacket with each packet fragment to -// write. It assumes that the IP header is entirely in pkt.Header but does not -// assume that only the IP header is in pkt.Header. It assumes that the input -// packet's stated length matches the length of the header+payload. mtu -// includes the IP header and options. This does not support the DontFragment -// IP flag. +// write. It assumes that the IP header is already present in pkt.NetworkHeader. +// pkt.TransportHeader may be set. mtu includes the IP header and options. This +// does not support the DontFragment IP flag. func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. - ip := header.IPv4(pkt.Header.View()) + ip := header.IPv4(pkt.NetworkHeader().View()) flags := ip.Flags() // Update mtu to take into account the header, which will exist in all @@ -148,85 +145,84 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, outerMTU := innerMTU + int(ip.HeaderLength()) offset := ip.FragmentOffset() - originalAvailableLength := pkt.Header.AvailableLength() + + // Keep the length reserved for link-layer, we need to create fragments with + // the same reserved length. + reservedForLink := pkt.AvailableHeaderBytes() + + // Destroy the packet, pull all payloads out for fragmentation. + transHeader, data := pkt.TransportHeader().View(), pkt.Data + + // Where possible, the first fragment that is sent has the same + // number of bytes reserved for header as the input packet. The link-layer + // endpoint may depend on this for looking at, eg, L4 headers. + transFitsFirst := len(transHeader) <= innerMTU + for i := 0; i < n; i++ { - // Where possible, the first fragment that is sent has the same - // pkt.Header.UsedLength() as the input packet. The link-layer - // endpoint may depend on this for looking at, eg, L4 headers. - h := ip - if i > 0 { - pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) - h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength()))) - copy(h, ip[:ip.HeaderLength()]) + reserve := reservedForLink + int(ip.HeaderLength()) + if i == 0 && transFitsFirst { + // Reserve for transport header if it's going to be put in the first + // fragment. + reserve += len(transHeader) + } + fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: reserve, + }) + fragPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + + // Copy data for the fragment. + avail := innerMTU + + if n := len(transHeader); n > 0 { + if n > avail { + n = avail + } + if i == 0 && transFitsFirst { + copy(fragPkt.TransportHeader().Push(n), transHeader) + } else { + fragPkt.Data.AppendView(transHeader[:n:n]) + } + transHeader = transHeader[n:] + avail -= n } + + if avail > 0 { + n := data.Size() + if n > avail { + n = avail + } + data.ReadToVV(&fragPkt.Data, n) + avail -= n + } + + copied := uint16(innerMTU - avail) + + // Set lengths in header and calculate checksum. + h := header.IPv4(fragPkt.NetworkHeader().Push(len(ip))) + copy(h, ip) if i != n-1 { h.SetTotalLength(uint16(outerMTU)) h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) } else { - h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size())) + h.SetTotalLength(uint16(h.HeaderLength()) + copied) h.SetFlagsFragmentOffset(flags, offset) } h.SetChecksum(0) h.SetChecksum(^h.CalculateChecksum()) - offset += uint16(innerMTU) - if i > 0 { - newPayload := pkt.Data.Clone(nil) - newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: pkt.Header, - Data: newPayload, - NetworkHeader: buffer.View(h), - }); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - pkt.Data.TrimFront(newPayload.Size()) - continue - } - // Special handling for the first fragment because it comes - // from the header. - if outerMTU >= pkt.Header.UsedLength() { - // This fragment can fit all of pkt.Header and possibly - // some of pkt.Data, too. - newPayload := pkt.Data.Clone(nil) - newPayloadLength := outerMTU - pkt.Header.UsedLength() - newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: pkt.Header, - Data: newPayload, - NetworkHeader: buffer.View(h), - }); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - pkt.Data.TrimFront(newPayloadLength) - } else { - // The fragment is too small to fit all of pkt.Header. - startOfHdr := pkt.Header - startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) - emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: startOfHdr, - Data: emptyVV, - NetworkHeader: buffer.View(h), - }); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - // Add the unused bytes of pkt.Header into the pkt.Data - // that remains to be sent. - restOfHdr := pkt.Header.View()[outerMTU:] - tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)}) - tmp.Append(pkt.Data) - pkt.Data = tmp + offset += copied + + // Send out the fragment. + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { + return err } + r.Stats().IP.PacketsSent.Increment() } return nil } -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - length := uint16(hdr.UsedLength() + payloadSize) +func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { + ip := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) + length := uint16(pkt.Size()) // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic // datagrams. Since the DF bit is never being set here, all datagrams // are non-atomic and need an ID. @@ -242,17 +238,16 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) - return ip + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber } // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) - pkt.NetworkHeader = buffer.View(ip) + e.addIPHeader(r, pkt, params) - nicName := e.stack.FindNICNameFromID(e.NICID()) // iptables filtering. All packets that reach here are locally // generated. + nicName := e.stack.FindNICNameFromID(e.NICID()) ipt := e.stack.IPTables() if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. @@ -265,7 +260,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // only NATted packets, but removing this check short circuits broadcasts // before they are sent out to other hosts. if pkt.NatDone { - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) if err == nil { route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) @@ -282,7 +277,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw if r.Loop&stack.PacketOut == 0 { return nil } - if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) } if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { @@ -302,8 +297,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pkt := pkts.Front(); pkt != nil; { - ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) - pkt.NetworkHeader = buffer.View(ip) + e.addIPHeader(r, pkt, params) pkt = pkt.Next() } @@ -328,7 +322,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe continue } if _, ok := natPkts[pkt]; ok { - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() @@ -397,17 +391,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu r.Stats().IP.PacketsSent.Increment() - ip = ip[:ip.HeaderLength()] - pkt.Header = buffer.NewPrependableFromView(buffer.View(ip)) - pkt.Data.TrimFront(int(ip.HeaderLength())) return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - h := header.IPv4(pkt.NetworkHeader) - if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) { + h := header.IPv4(pkt.NetworkHeader().View()) + if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } @@ -421,7 +412,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } if h.More() || h.FragmentOffset() != 0 { - if pkt.Data.Size()+len(pkt.TransportHeader) == 0 { + if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. r.Stats().IP.MalformedPacketsReceived.Increment() @@ -465,7 +456,6 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { - pkt.NetworkHeader.CapLength(int(h.HeaderLength())) e.handleICMP(r, pkt) return } @@ -555,14 +545,19 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu } ipHdr := header.IPv4(hdr) - // If there are options, pull those into hdr as well. - if headerLen := int(ipHdr.HeaderLength()); headerLen > header.IPv4MinimumSize && headerLen <= pkt.Data.Size() { - hdr, ok = pkt.Data.PullUp(headerLen) - if !ok { - panic(fmt.Sprintf("There are only %d bytes in pkt.Data, but there should be at least %d", pkt.Data.Size(), headerLen)) - } - ipHdr = header.IPv4(hdr) + // Header may have options, determine the true header length. + headerLen := int(ipHdr.HeaderLength()) + if headerLen < header.IPv4MinimumSize { + // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in + // order for the packet to be valid. Figure out if we want to reject this + // case. + headerLen = header.IPv4MinimumSize + } + hdr, ok = pkt.NetworkHeader().Consume(headerLen) + if !ok { + return 0, false, false } + ipHdr = header.IPv4(hdr) // If this is a fragment, don't bother parsing the transport header. parseTransportHeader := true @@ -570,8 +565,7 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu parseTransportHeader = false } - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(len(hdr)) + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) return ipHdr.TransportProtocol(), parseTransportHeader, true } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index ded97ac64..afd3ac06d 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -17,6 +17,7 @@ package ipv4_test import ( "bytes" "encoding/hex" + "fmt" "math/rand" "testing" @@ -91,15 +92,11 @@ func TestExcludeBroadcast(t *testing.T) { }) } -// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much +// makeRandPkt generates a randomize packet. hdrLength indicates how much // data should already be in the header before WritePacket. extraLength // indicates how much extra space should be in the header. The payload is made // from many Views of the sizes listed in viewSizes. -func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) { - hdr := buffer.NewPrependable(hdrLength + extraLength) - hdr.Prepend(hdrLength) - rand.Read(hdr.View()) - +func makeRandPkt(hdrLength int, extraLength int, viewSizes []int) *stack.PacketBuffer { var views []buffer.View totalLength := 0 for _, s := range viewSizes { @@ -108,8 +105,16 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. views = append(views, newView) totalLength += s } - payload := buffer.NewVectorisedView(totalLength, views) - return hdr, payload + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: hdrLength + extraLength, + Data: buffer.NewVectorisedView(totalLength, views), + }) + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + if _, err := rand.Read(pkt.TransportHeader().Push(hdrLength)); err != nil { + panic(fmt.Sprintf("rand.Read: %s", err)) + } + return pkt } // comparePayloads compared the contents of all the packets against the contents @@ -117,9 +122,9 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. - source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) - source = append(source, sourcePacketInfo.Header.View()...) - source = append(source, sourcePacketInfo.Data.ToView()...) + source := header.IPv4(packets[0].NetworkHeader().View()[:header.IPv4MinimumSize]) + vv := buffer.NewVectorisedView(sourcePacketInfo.Size(), sourcePacketInfo.Views()) + source = append(source, vv.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make // an expected header. @@ -132,8 +137,7 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI var reassembledPayload []byte for i, packet := range packets { // Confirm that the packet is valid. - allBytes := packet.Header.View().ToVectorisedView() - allBytes.Append(packet.Data) + allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views()) ip := header.IPv4(allBytes.ToView()) if !ip.IsValid(len(ip)) { t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) @@ -144,12 +148,22 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI if got, want := len(ip), int(mtu); got > want { t.Errorf("fragment is too large, got %d want %d", got, want) } - if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want { - t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) + if i == 0 { + got := packet.NetworkHeader().View().Size() + packet.TransportHeader().View().Size() + // sourcePacketInfo does not have NetworkHeader added, simulate one. + want := header.IPv4MinimumSize + sourcePacketInfo.TransportHeader().View().Size() + // Check that it kept the transport header in packet.TransportHeader if + // it fits in the first fragment. + if want < int(mtu) && got != want { + t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) + } } - if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want { + if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want { t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) } + if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want { + t.Errorf("fragment #%d has wrong network protocol number: got %d, want %d", i, got, want) + } if i < len(packets)-1 { sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) } else { @@ -281,21 +295,14 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := &stack.PacketBuffer{ - Header: hdr, - // Save the source payload because WritePacket will modify it. - Data: payload.Clone(nil), - } + pkt := makeRandPkt(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) + source := pkt.Clone() c := buildContext(t, nil, ft.mtu) err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload, - }) + }, pkt) if err != nil { t.Errorf("err got %v, want %v", err, nil) } @@ -340,16 +347,13 @@ func TestFragmentationErrors(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) + pkt := makeRandPkt(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload, - }) + }, pkt) for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) @@ -468,9 +472,9 @@ func TestInvalidFragments(t *testing.T) { s.CreateNIC(nicID, sniffer.New(ep)) for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, &stack.PacketBuffer{ + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), - }) + })) } if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { @@ -855,9 +859,9 @@ func TestReceiveFragments(t *testing.T) { vv := hdr.View().ToVectorisedView() vv.AppendView(frag.payload) - e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - }) + })) } if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index ded91d83a..39ae19295 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -83,7 +83,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } h := header.ICMPv6(v) - iph := header.IPv6(pkt.NetworkHeader) + iph := header.IPv6(pkt.NetworkHeader().View()) // Validate ICMPv6 checksum before processing the packet. // @@ -276,8 +276,10 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme optsSerializer := header.NDPOptionsSerializer{ header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress), } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) - packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()), + }) + packet := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize)) packet.SetType(header.ICMPv6NeighborAdvert) na := header.NDPNeighborAdvert(packet.NDPPayload()) na.SetSolicitedFlag(solicited) @@ -293,9 +295,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - }); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil { sent.Dropped.Increment() return } @@ -384,7 +384,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) + icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize) if !ok { received.Invalid.Increment() return @@ -409,16 +409,15 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // Use the link address from the source of the original packet. r.ResolveWith(remoteLinkAddr) - pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) - packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, + Data: pkt.Data, + }) + packet := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: pkt.Data, - }); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, replyPkt); err != nil { sent.Dropped.Increment() return } @@ -539,17 +538,19 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr) } - hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - copy(pkt[icmpV6OptOffset-len(addr):], addr) - pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr - pkt[icmpV6LengthOffset] = 1 - copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress()) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - - length := uint16(hdr.UsedLength()) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize, + }) + icmpHdr := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize)) + icmpHdr.SetType(header.ICMPv6NeighborSolicit) + copy(icmpHdr[icmpV6OptOffset-len(addr):], addr) + icmpHdr[icmpV6OptOffset] = ndpOptSrcLinkAddr + icmpHdr[icmpV6LengthOffset] = 1 + copy(icmpHdr[icmpV6LengthOffset+1:], linkEP.LinkAddress()) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + length := uint16(pkt.Size()) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(header.ICMPv6ProtocolNumber), @@ -559,9 +560,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - }) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver. diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index f86aaed1d..2a2f7de01 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -183,7 +183,11 @@ func TestICMPCounts(t *testing.T) { } handleIPv6Payload := func(icmp header.ICMPv6) { - ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize, + Data: buffer.View(icmp).ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), @@ -191,10 +195,7 @@ func TestICMPCounts(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, &stack.PacketBuffer{ - NetworkHeader: buffer.View(ip), - Data: buffer.View(icmp).ToVectorisedView(), - }) + ep.HandlePacket(&r, pkt) } for _, typ := range types { @@ -323,12 +324,10 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. pi, _ := args.src.ReadContext(context.Background()) { - views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} - size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() - vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), &stack.PacketBuffer{ - Data: vv, + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pi.Pkt.Size(), pi.Pkt.Views()), }) + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), pkt) } if pi.Proto != ProtocolNumber { @@ -340,7 +339,9 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) } - ipv6 := header.IPv6(pi.Pkt.Header.View()) + // Pull the full payload since network header. Needed for header.IPv6 to + // extract its payload. + ipv6 := header.IPv6(stack.PayloadSince(pi.Pkt.NetworkHeader())) transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) if transProto != header.ICMPv6ProtocolNumber { t.Errorf("unexpected transport protocol number %d", transProto) @@ -558,9 +559,10 @@ func TestICMPChecksumValidationSimple(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived @@ -719,12 +721,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { icmpSize := size + payloadSize hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(typ) - payloadFn(pkt.Payload()) + icmpHdr := header.ICMPv6(hdr.Prepend(icmpSize)) + icmpHdr.SetType(typ) + payloadFn(icmpHdr.Payload()) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, buffer.VectorisedView{})) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -735,9 +737,10 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived @@ -895,14 +898,14 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) - pkt := header.ICMPv6(hdr.Prepend(size)) - pkt.SetType(typ) + icmpHdr := header.ICMPv6(hdr.Prepend(size)) + icmpHdr.SetType(typ) payload := buffer.NewView(payloadSize) payloadFn(payload) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView())) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, payload.ToVectorisedView())) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -913,9 +916,10 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 4a0b53c45..0ade655b2 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -99,9 +99,9 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 { - length := uint16(hdr.UsedLength() + payloadSize) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) +func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { + length := uint16(pkt.Size()) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(params.Protocol), @@ -110,25 +110,20 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - return ip + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber } // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) - pkt.NetworkHeader = buffer.View(ip) + e.addIPHeader(r, pkt, params) if r.Loop&stack.PacketLoop != 0 { - // The inbound path expects the network header to still be in - // the PacketBuffer's Data field. - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) + e.HandlePacket(&loopedR, stack.NewPacketBuffer(stack.PacketBufferOptions{ + // The inbound path expects an unparsed packet. + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + })) loopedR.Release() } @@ -150,8 +145,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pb := pkts.Front(); pb != nil; pb = pb.Next() { - ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params) - pb.NetworkHeader = buffer.View(ip) + e.addIPHeader(r, pb, params) } n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) @@ -169,8 +163,8 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuff // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - h := header.IPv6(pkt.NetworkHeader) - if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) { + h := header.IPv6(pkt.NetworkHeader().View()) + if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } @@ -179,8 +173,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // - Any IPv6 header bytes after the first 40 (i.e. extensions). // - The transport header, if present. // - Any other payload data. - vv := pkt.NetworkHeader[header.IPv6MinimumSize:].ToVectorisedView() - vv.AppendView(pkt.TransportHeader) + vv := pkt.NetworkHeader().View()[header.IPv6MinimumSize:].ToVectorisedView() + vv.AppendView(pkt.TransportHeader().View()) vv.Append(pkt.Data) it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) hasFragmentHeader := false @@ -408,7 +402,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // // For reassembled fragments, pkt.TransportHeader is unset, so this is a // no-op and pkt.Data begins with the transport header. - extHdr.Buf.TrimFront(len(pkt.TransportHeader)) + extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) pkt.Data = extHdr.Buf if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { @@ -579,16 +573,14 @@ traverseExtensions: } } - // Put the IPv6 header with extensions in pkt.NetworkHeader. - hdr, ok = pkt.Data.PullUp(header.IPv6MinimumSize + extensionsSize) + // Put the IPv6 header with extensions in pkt.NetworkHeader(). + hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize) if !ok { panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) } ipHdr = header.IPv6(hdr) - - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(len(hdr)) pkt.Data.CapLength(int(ipHdr.PayloadLength())) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber return nextHdr, foundNext, true } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 3d65814de..081afb051 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -65,9 +65,9 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) stats := s.Stats().ICMP.V6PacketsReceived @@ -123,9 +123,9 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) stat := s.Stats().UDP.PacketsReceived @@ -637,9 +637,9 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { DstAddr: addr2, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) stats := s.Stats().UDP.PacketsReceived @@ -1469,9 +1469,9 @@ func TestReceiveIPv6Fragments(t *testing.T) { vv := hdr.View().ToVectorisedView() vv.Append(f.data) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - }) + })) } if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 64239ce9a..2efa82e60 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -136,9 +136,9 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) if linkAddr != test.expectedLinkAddr { @@ -380,9 +380,9 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, &stack.PacketBuffer{ + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) if test.nsInvalid { if got := invalid.Value(); got != 1 { @@ -410,7 +410,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) } - checker.IPv6(t, p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(test.naSrc), checker.DstAddr(test.naDst), checker.TTL(header.NDPHopLimit), @@ -497,9 +497,9 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) if linkAddr != test.expectedLinkAddr { @@ -560,7 +560,11 @@ func TestNDPValidation(t *testing.T) { nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) } - ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize + len(extensions))) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions), + Data: payload.ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions))) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(len(payload) + len(extensions)), NextHeader: nextHdr, @@ -571,10 +575,7 @@ func TestNDPValidation(t *testing.T) { if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) } - ep.HandlePacket(r, &stack.PacketBuffer{ - NetworkHeader: buffer.View(ip), - Data: payload.ToVectorisedView(), - }) + ep.HandlePacket(r, pkt) } var tllData [header.NDPLinkLayerAddressSize]byte @@ -636,7 +637,7 @@ func TestNDPValidation(t *testing.T) { name string atomicFragment bool hopLimit uint8 - code uint8 + code header.ICMPv6Code valid bool }{ { @@ -729,7 +730,7 @@ func TestRouterAdvertValidation(t *testing.T) { name string src tcpip.Address hopLimit uint8 - code uint8 + code header.ICMPv6Code ndpPayload []byte expectedSuccess bool }{ @@ -885,9 +886,9 @@ func TestRouterAdvertValidation(t *testing.T) { t.Fatalf("got rxRA = %d, want = 0", got) } - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) if got := rxRA.Value(); got != 1 { t.Fatalf("got rxRA = %d, want = 1", got) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index bfc7a0c7c..900938dd1 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -57,6 +57,7 @@ go_library( "conntrack.go", "dhcpv6configurationfromndpra_string.go", "forwarder.go", + "headertype_string.go", "icmp_rate_limit.go", "iptables.go", "iptables_state.go", @@ -143,6 +144,7 @@ go_test( "neighbor_cache_test.go", "neighbor_entry_test.go", "nic_test.go", + "packet_buffer_test.go", ], library = ":stack", deps = [ diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 470c265aa..7dd344b4f 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -199,12 +199,12 @@ type bucket struct { func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { // TODO(gvisor.dev/issue/170): Need to support for other // protocols as well. - netHeader := header.IPv4(pkt.NetworkHeader) - if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber { return tupleID{}, tcpip.ErrUnknownProtocol } - tcpHeader := header.TCP(pkt.TransportHeader) - if tcpHeader == nil { + tcpHeader := header.TCP(pkt.TransportHeader().View()) + if len(tcpHeader) < header.TCPMinimumSize { return tupleID{}, tcpip.ErrUnknownProtocol } @@ -344,8 +344,8 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { return } - netHeader := header.IPv4(pkt.NetworkHeader) - tcpHeader := header.TCP(pkt.TransportHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) + tcpHeader := header.TCP(pkt.TransportHeader().View()) // For prerouting redirection, packets going in the original direction // have their destinations modified and replies have their sources @@ -377,8 +377,8 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d return } - netHeader := header.IPv4(pkt.NetworkHeader) - tcpHeader := header.TCP(pkt.TransportHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) + tcpHeader := header.TCP(pkt.TransportHeader().View()) // For output redirection, packets going in the original direction // have their destinations modified and replies have their sources @@ -396,8 +396,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - hdr := &pkt.Header - length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength()) + length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) @@ -423,7 +422,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou } // TODO(gvisor.dev/issue/170): Support other transport protocols. - if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { return false } @@ -433,8 +432,8 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou return true } - tcpHeader := header.TCP(pkt.TransportHeader) - if tcpHeader == nil { + tcpHeader := header.TCP(pkt.TransportHeader().View()) + if len(tcpHeader) < header.TCPMinimumSize { return false } @@ -455,7 +454,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou // Mark the connection as having been used recently so it isn't reaped. conn.lastUsed = time.Now() // Update connection state. - conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) return false } @@ -474,7 +473,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { } // We only track TCP connections. - if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { return } @@ -486,7 +485,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { return } conn := newConn(tid, tid.reply(), manipNone, hook) - conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) ct.insertConn(conn) } diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index c962693f5..944f622fd 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -75,7 +75,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -97,7 +97,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. - b := pkt.Header.Prepend(fwdTestNetHeaderLen) + b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) b[dstAddrOffset] = r.RemoteAddress[0] b[srcAddrOffset] = f.id.LocalAddress[0] b[protocolNumberOffset] = byte(params.Protocol) @@ -144,13 +144,11 @@ func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Add } func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + netHeader, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen) if !ok { return 0, false, false } - pkt.NetworkHeader = netHeader - pkt.Data.TrimFront(fwdTestNetHeaderLen) - return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true + return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true } func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { @@ -290,7 +288,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := fwdTestPacketInfo{ - Pkt: &PacketBuffer{Data: vv}, + Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}), } select { @@ -382,9 +380,9 @@ func TestForwardingWithStaticResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) var p fwdTestPacketInfo @@ -419,9 +417,9 @@ func TestForwardingWithFakeResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) var p fwdTestPacketInfo @@ -450,9 +448,9 @@ func TestForwardingWithNoResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) select { case <-ep2.C: @@ -480,17 +478,17 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // not be forwarded. buf := buffer.NewView(30) buf[dstAddrOffset] = 4 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf = buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) var p fwdTestPacketInfo @@ -500,8 +498,8 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) } // Test that the address resolution happened correctly. @@ -529,9 +527,9 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { for i := 0; i < 2; i++ { buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } for i := 0; i < 2; i++ { @@ -543,8 +541,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) } // Test that the address resolution happened correctly. @@ -575,9 +573,9 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { buf[dstAddrOffset] = 3 // Set the packet sequence number. binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } for i := 0; i < maxPendingPacketsPerResolution; i++ { @@ -589,13 +587,14 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 { + b := PayloadSince(p.Pkt.NetworkHeader()) + if b[dstAddrOffset] != 3 { t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) } - seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes). - if !ok { - t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size()) + if len(b) < fwdTestNetHeaderLen+2 { + t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) } + seqNumBuf := b[fwdTestNetHeaderLen:] // The first 5 packets should not be forwarded so the sequence number should // start with 5. @@ -632,9 +631,9 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // maxPendingResolutions + 7). buf := buffer.NewView(30) buf[dstAddrOffset] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } for i := 0; i < maxPendingResolutions; i++ { @@ -648,8 +647,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - if p.Pkt.NetworkHeader[dstAddrOffset] < 8 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset]) + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) } // Test that the address resolution happened correctly. diff --git a/pkg/tcpip/stack/headertype_string.go b/pkg/tcpip/stack/headertype_string.go new file mode 100644 index 000000000..5efddfaaf --- /dev/null +++ b/pkg/tcpip/stack/headertype_string.go @@ -0,0 +1,39 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type headerType ."; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[linkHeader-0] + _ = x[networkHeader-1] + _ = x[transportHeader-2] + _ = x[numHeaderType-3] +} + +const _headerType_name = "linkHeadernetworkHeadertransportHeadernumHeaderType" + +var _headerType_index = [...]uint8{0, 10, 23, 38, 51} + +func (i headerType) String() string { + if i < 0 || i >= headerType(len(_headerType_index)-1) { + return "headerType(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _headerType_name[_headerType_index[i]:_headerType_index[i+1]] +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 110ba073d..c37da814f 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -394,7 +394,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. - if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) { + if !rule.Filter.match(header.IPv4(pkt.NetworkHeader().View()), hook, nicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index dc88033c7..5f1b2af64 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -99,7 +99,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso } // Drop the packet if network and transport header are not set. - if pkt.NetworkHeader == nil || pkt.TransportHeader == nil { + if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() { return RuleDrop, 0 } @@ -118,17 +118,16 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if // we need to change dest address (for OUTPUT chain) or ports. - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) switch protocol := netHeader.TransportProtocol(); protocol { case header.UDPProtocolNumber: - udpHeader := header.UDP(pkt.TransportHeader) + udpHeader := header.UDP(pkt.TransportHeader().View()) udpHeader.SetDestinationPort(rt.MinPort) // Calculate UDP checksum and set it. if hook == Output { udpHeader.SetChecksum(0) - hdr := &pkt.Header - length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength()) + length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) // Only calculate the checksum if offloading isn't supported. if r.Capabilities()&CapabilityTXChecksumOffload == 0 { diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 5174e639c..93567806b 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -746,12 +746,16 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())) } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) + icmpData.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmpData.NDPPayload()) ns.SetTargetAddress(addr) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(icmpData).ToVectorisedView(), + }) sent := r.Stats().ICMP.V6PacketsSent if err := r.WritePacket(nil, @@ -759,7 +763,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, &PacketBuffer{Header: hdr}, + }, pkt, ); err != nil { sent.Dropped.Increment() return err @@ -1897,12 +1901,16 @@ func (ndp *ndpState) startSolicitingRouters() { } } payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize) - pkt := header.ICMPv6(hdr.Prepend(payloadSize)) - pkt.SetType(header.ICMPv6RouterSolicit) - rs := header.NDPRouterSolicit(pkt.NDPPayload()) + icmpData := header.ICMPv6(buffer.NewView(payloadSize)) + icmpData.SetType(header.ICMPv6RouterSolicit) + rs := header.NDPRouterSolicit(icmpData.NDPPayload()) rs.Options().Serialize(optsSerializer) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(icmpData).ToVectorisedView(), + }) sent := r.Stats().ICMP.V6PacketsSent if err := r.WritePacket(nil, @@ -1910,7 +1918,7 @@ func (ndp *ndpState) startSolicitingRouters() { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, &PacketBuffer{Header: hdr}, + }, pkt, ); err != nil { sent.Dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 5d286ccbc..21bf53010 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -541,7 +541,7 @@ func TestDADResolve(t *testing.T) { // As per RFC 4861 section 4.3, a possible option is the Source Link // Layer option, but this option MUST NOT be included when the source // address of the packet is the unspecified address. - checker.IPv6(t, p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(header.IPv6Any), checker.DstAddr(snmc), checker.TTL(header.NDPHopLimit), @@ -550,8 +550,8 @@ func TestDADResolve(t *testing.T) { checker.NDPNSOptions(nil), )) - if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) } } }) @@ -667,9 +667,10 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate multiple nodes owning or // attempting to own the same address. hdr := test.makeBuf(addr1) - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) + e.InjectInbound(header.IPv6ProtocolNumber, pkt) stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) if got := stat.Value(); got != 1 { @@ -1024,7 +1025,9 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo DstAddr: header.IPv6AllNodesMulticastAddress, }) - return &stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} + return stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) } // raBufWithOpts returns a valid NDP Router Advertisement with options. @@ -5134,16 +5137,15 @@ func TestRouterSolicitation(t *testing.T) { t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) } - checker.IPv6(t, - p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(test.expectedSrcAddr), checker.DstAddr(header.IPv6AllRoutersMulticastAddress), checker.TTL(header.NDPHopLimit), checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), ) - if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) } } waitForNothing := func(timeout time.Duration) { @@ -5288,7 +5290,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { if p.Proto != header.IPv6ProtocolNumber { t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) } - checker.IPv6(t, p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(header.IPv6Any), checker.DstAddr(header.IPv6AllRoutersMulticastAddress), checker.TTL(header.NDPHopLimit), diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index eaaf756cd..2315ea5b9 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1299,7 +1299,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - src, dst := netProto.ParseAddresses(pkt.NetworkHeader) + src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1401,24 +1401,19 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this - // by having lower layers explicity write each header instead of just - // pkt.Header. - // pkt may have set its NetworkHeader and TransportHeader. If we're - // forwarding, we'll have to copy them into pkt.Header. - pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) - if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader))) - } - if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader))) - } + // pkt may have set its header and may not have enough headroom for link-layer + // header for the other link to prepend. Here we create a new packet to + // forward. + fwdPkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()), + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + }) - // WritePacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + // WritePacket takes ownership of fwdPkt, calculate numBytes first. + numBytes := fwdPkt.Size() - if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { + if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return } @@ -1443,34 +1438,31 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - // TransportHeader is nil only when pkt is an ICMP packet or was reassembled + // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. - if pkt.TransportHeader == nil { + if pkt.TransportHeader().View().IsEmpty() { // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a // full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { // ICMP packets may be longer, but until icmp.Parse is implemented, here // we parse it using the minimum size. - transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) - if !ok { + if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - pkt.TransportHeader = transHeader - pkt.Data.TrimFront(len(pkt.TransportHeader)) } else { // This is either a bad packet or was re-assembled from fragments. transProto.Parse(pkt) } } - if len(pkt.TransportHeader) < transProto.MinimumPacketSize() { + if pkt.TransportHeader().View().Size() < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index a70792b50..0870c8d9c 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -311,7 +311,9 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { t.FailNow() } - nic.DeliverNetworkPacket("", "", 0, &PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) + nic.DeliverNetworkPacket("", "", 0, NewPacketBuffer(PacketBufferOptions{ + Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), + })) if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 5d6865e35..17b8beebb 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -14,16 +14,43 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) +type headerType int + +const ( + linkHeader headerType = iota + networkHeader + transportHeader + numHeaderType +) + +// PacketBufferOptions specifies options for PacketBuffer creation. +type PacketBufferOptions struct { + // ReserveHeaderBytes is the number of bytes to reserve for headers. Total + // number of bytes pushed onto the headers must not exceed this value. + ReserveHeaderBytes int + + // Data is the initial unparsed data for the new packet. If set, it will be + // owned by the new packet. + Data buffer.VectorisedView +} + // A PacketBuffer contains all the data of a network packet. // // As a PacketBuffer traverses up the stack, it may be necessary to pass it to -// multiple endpoints. Clone() should be called in such cases so that -// modifications to the Data field do not affect other copies. +// multiple endpoints. +// +// The whole packet is expected to be a series of bytes in the following order: +// LinkHeader, NetworkHeader, TransportHeader, and Data. Any of them can be +// empty. Use of PacketBuffer in any other order is unsupported. +// +// PacketBuffer must be created with NewPacketBuffer. type PacketBuffer struct { _ sync.NoCopy @@ -31,36 +58,32 @@ type PacketBuffer struct { // PacketBuffers. PacketBufferEntry - // Data holds the payload of the packet. For inbound packets, it also - // holds the headers, which are consumed as the packet moves up the - // stack. Headers are guaranteed not to be split across views. + // Data holds the payload of the packet. + // + // For inbound packets, Data is initially the whole packet. Then gets moved to + // headers via PacketHeader.Consume, when the packet is being parsed. + // + // For outbound packets, Data is the innermost layer, defined by the protocol. + // Headers are pushed in front of it via PacketHeader.Push. // - // The bytes backing Data are immutable, but Data itself may be trimmed - // or otherwise modified. + // The bytes backing Data are immutable, a.k.a. users shouldn't write to its + // backing storage. Data buffer.VectorisedView - // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. Note that forwarded - // packets don't populate Headers on their way out -- their headers and - // payload are never parsed out and remain in Data. - // - // TODO(gvisor.dev/issue/170): Forwarded packets don't currently - // populate Header, but should. This will be doable once early parsing - // (https://github.com/google/gvisor/pull/1995) is supported. - Header buffer.Prependable + // headers stores metadata about each header. + headers [numHeaderType]headerInfo - // These fields are used by both inbound and outbound packets. They - // typically overlap with the Data and Header fields. - // - // The bytes backing these views are immutable. Each field may be nil - // if either it has not been set yet or no such header exists (e.g. - // packets sent via loopback may not have a link header). + // header is the internal storage for outbound packets. Headers will be pushed + // (prepended) on this storage as the packet is being constructed. // - // These fields may be Views into other slices (either Data or Header). - // SR dosen't support this, so deep copies are necessary in some cases. - LinkHeader buffer.View - NetworkHeader buffer.View - TransportHeader buffer.View + // TODO(gvisor.dev/issue/2404): Switch to an implementation that header and + // data are held in the same underlying buffer storage. + header buffer.Prependable + + // NetworkProtocol is only valid when NetworkHeader is set. + // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol + // numbers in registration APIs that take a PacketBuffer. + NetworkProtocolNumber tcpip.NetworkProtocolNumber // Hash is the transport layer hash of this packet. A value of zero // indicates no valid hash has been set. @@ -72,9 +95,8 @@ type PacketBuffer struct { // The following fields are only set by the qdisc layer when the packet // is added to a queue. - EgressRoute *Route - GSOOptions *GSO - NetworkProtocolNumber tcpip.NetworkProtocolNumber + EgressRoute *Route + GSOOptions *GSO // NatDone indicates if the packet has been manipulated as per NAT // iptables rule. @@ -85,20 +107,137 @@ type PacketBuffer struct { PktType tcpip.PacketType } -// Clone makes a copy of pk. It clones the Data field, which creates a new -// VectorisedView but does not deep copy the underlying bytes. -// -// Clone also does not deep copy any of its other fields. +// NewPacketBuffer creates a new PacketBuffer with opts. +func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { + pk := &PacketBuffer{ + Data: opts.Data, + } + if opts.ReserveHeaderBytes != 0 { + pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes) + } + return pk +} + +// ReservedHeaderBytes returns the number of bytes initially reserved for +// headers. +func (pk *PacketBuffer) ReservedHeaderBytes() int { + return pk.header.UsedLength() + pk.header.AvailableLength() +} + +// AvailableHeaderBytes returns the number of bytes currently available for +// headers. This is relevant to PacketHeader.Push method only. +func (pk *PacketBuffer) AvailableHeaderBytes() int { + return pk.header.AvailableLength() +} + +// LinkHeader returns the handle to link-layer header. +func (pk *PacketBuffer) LinkHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: linkHeader, + } +} + +// NetworkHeader returns the handle to network-layer header. +func (pk *PacketBuffer) NetworkHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: networkHeader, + } +} + +// TransportHeader returns the handle to transport-layer header. +func (pk *PacketBuffer) TransportHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: transportHeader, + } +} + +// HeaderSize returns the total size of all headers in bytes. +func (pk *PacketBuffer) HeaderSize() int { + // Note for inbound packets (Consume called), headers are not stored in + // pk.header. Thus, calculation of size of each header is needed. + var size int + for i := range pk.headers { + size += len(pk.headers[i].buf) + } + return size +} + +// Size returns the size of packet in bytes. +func (pk *PacketBuffer) Size() int { + return pk.HeaderSize() + pk.Data.Size() +} + +// Views returns the underlying storage of the whole packet. +func (pk *PacketBuffer) Views() []buffer.View { + // Optimization for outbound packets that headers are in pk.header. + useHeader := true + for i := range pk.headers { + if !canUseHeader(&pk.headers[i]) { + useHeader = false + break + } + } + + dataViews := pk.Data.Views() + + var vs []buffer.View + if useHeader { + vs = make([]buffer.View, 0, 1+len(dataViews)) + vs = append(vs, pk.header.View()) + } else { + vs = make([]buffer.View, 0, len(pk.headers)+len(dataViews)) + for i := range pk.headers { + if v := pk.headers[i].buf; len(v) > 0 { + vs = append(vs, v) + } + } + } + return append(vs, dataViews...) +} + +func canUseHeader(h *headerInfo) bool { + // h.offset will be negative if the header was pushed in to prependable + // portion, or doesn't matter when it's empty. + return len(h.buf) == 0 || h.offset < 0 +} + +func (pk *PacketBuffer) push(typ headerType, size int) buffer.View { + h := &pk.headers[typ] + if h.buf != nil { + panic(fmt.Sprintf("push must not be called twice: type %s", typ)) + } + h.buf = buffer.View(pk.header.Prepend(size)) + h.offset = -pk.header.UsedLength() + return h.buf +} + +func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consumed bool) { + h := &pk.headers[typ] + if h.buf != nil { + panic(fmt.Sprintf("consume must not be called twice: type %s", typ)) + } + v, ok := pk.Data.PullUp(size) + if !ok { + return + } + pk.Data.TrimFront(size) + h.buf = v + return h.buf, true +} + +// Clone makes a shallow copy of pk. // -// FIXME(b/153685824): Data gets copied but not other header references. +// Clone should be called in such cases so that no modifications is done to +// underlying packet payload. func (pk *PacketBuffer) Clone() *PacketBuffer { - return &PacketBuffer{ + newPk := &PacketBuffer{ PacketBufferEntry: pk.PacketBufferEntry, Data: pk.Data.Clone(nil), - Header: pk.Header, - LinkHeader: pk.LinkHeader, - NetworkHeader: pk.NetworkHeader, - TransportHeader: pk.TransportHeader, + headers: pk.headers, + header: pk.header, Hash: pk.Hash, Owner: pk.Owner, EgressRoute: pk.EgressRoute, @@ -106,4 +245,55 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { NetworkProtocolNumber: pk.NetworkProtocolNumber, NatDone: pk.NatDone, } + return newPk +} + +// headerInfo stores metadata about a header in a packet. +type headerInfo struct { + // buf is the memorized slice for both prepended and consumed header. + // When header is prepended, buf serves as memorized value, which is a slice + // of pk.header. When header is consumed, buf is the slice pulled out from + // pk.Data, which is the only place to hold this header. + buf buffer.View + + // offset will be a negative number denoting the offset where this header is + // from the end of pk.header, if it is prepended. Otherwise, zero. + offset int +} + +// PacketHeader is a handle object to a header in the underlying packet. +type PacketHeader struct { + pk *PacketBuffer + typ headerType +} + +// View returns the underlying storage of h. +func (h PacketHeader) View() buffer.View { + return h.pk.headers[h.typ].buf +} + +// Push pushes size bytes in the front of its residing packet, and returns the +// backing storage. Callers may only call one of Push or Consume once on each +// header in the lifetime of the underlying packet. +func (h PacketHeader) Push(size int) buffer.View { + return h.pk.push(h.typ, size) +} + +// Consume moves the first size bytes of the unparsed data portion in the packet +// to h, and returns the backing storage. In the case of data is shorter than +// size, consumed will be false, and the state of h will not be affected. +// Callers may only call one of Push or Consume once on each header in the +// lifetime of the underlying packet. +func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { + return h.pk.consume(h.typ, size) +} + +// PayloadSince returns packet payload starting from and including a particular +// header. This method isn't optimized and should be used in test only. +func PayloadSince(h PacketHeader) buffer.View { + var v buffer.View + for _, hinfo := range h.pk.headers[h.typ:] { + v = append(v, hinfo.buf...) + } + return append(v, h.pk.Data.ToView()...) } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go new file mode 100644 index 000000000..c6fa8da5f --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -0,0 +1,397 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "bytes" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func TestPacketHeaderPush(t *testing.T) { + for _, test := range []struct { + name string + reserved int + link []byte + network []byte + transport []byte + data []byte + }{ + { + name: "construct empty packet", + }, + { + name: "construct link header only packet", + reserved: 60, + link: makeView(10), + }, + { + name: "construct link and network header only packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + }, + { + name: "construct header only packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + transport: makeView(30), + }, + { + name: "construct data only packet", + data: makeView(40), + }, + { + name: "construct L3 packet", + reserved: 60, + network: makeView(20), + transport: makeView(30), + data: makeView(40), + }, + { + name: "construct L2 packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + transport: makeView(30), + data: makeView(40), + }, + } { + t.Run(test.name, func(t *testing.T) { + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: test.reserved, + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), + }) + + allHdrSize := len(test.link) + len(test.network) + len(test.transport) + + // Check the initial values for packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + ReserveHeaderBytes: test.reserved, + Data: buffer.View(test.data).ToVectorisedView(), + }) + + // Push headers. + if v := test.transport; len(v) > 0 { + copy(pk.TransportHeader().Push(len(v)), v) + } + if v := test.network; len(v) > 0 { + copy(pk.NetworkHeader().Push(len(v)), v) + } + if v := test.link; len(v) > 0 { + copy(pk.LinkHeader().Push(len(v)), v) + } + + // Check the after values for packet. + if got, want := pk.ReservedHeaderBytes(), test.reserved; got != want { + t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), test.reserved-allHdrSize; got != want { + t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), allHdrSize; got != want { + t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) + } + if got, want := pk.Size(), allHdrSize+len(test.data); got != want { + t.Errorf("After pk.Size() = %d, want %d", got, want) + } + checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), test.data) + checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), + concatViews(test.link, test.network, test.transport, test.data)) + // Check the after values for each header. + checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link) + checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network) + checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport) + // Check the after values for PayloadSince. + checkViewEqual(t, "After PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(test.link, test.network, test.transport, test.data)) + checkViewEqual(t, "After PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(test.network, test.transport, test.data)) + checkViewEqual(t, "After PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(test.transport, test.data)) + }) + } +} + +func TestPacketHeaderConsume(t *testing.T) { + for _, test := range []struct { + name string + data []byte + link int + network int + transport int + }{ + { + name: "parse L2 packet", + data: concatViews(makeView(10), makeView(20), makeView(30), makeView(40)), + link: 10, + network: 20, + transport: 30, + }, + { + name: "parse L3 packet", + data: concatViews(makeView(20), makeView(30), makeView(40)), + network: 20, + transport: 30, + }, + } { + t.Run(test.name, func(t *testing.T) { + pk := NewPacketBuffer(PacketBufferOptions{ + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), + }) + + // Check the initial values for packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + Data: buffer.View(test.data).ToVectorisedView(), + }) + + // Consume headers. + if size := test.link; size > 0 { + if _, ok := pk.LinkHeader().Consume(size); !ok { + t.Fatalf("pk.LinkHeader().Consume() = false, want true") + } + } + if size := test.network; size > 0 { + if _, ok := pk.NetworkHeader().Consume(size); !ok { + t.Fatalf("pk.NetworkHeader().Consume() = false, want true") + } + } + if size := test.transport; size > 0 { + if _, ok := pk.TransportHeader().Consume(size); !ok { + t.Fatalf("pk.TransportHeader().Consume() = false, want true") + } + } + + allHdrSize := test.link + test.network + test.transport + + // Check the after values for packet. + if got, want := pk.ReservedHeaderBytes(), 0; got != want { + t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), 0; got != want { + t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), allHdrSize; got != want { + t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) + } + if got, want := pk.Size(), len(test.data); got != want { + t.Errorf("After pk.Size() = %d, want %d", got, want) + } + // After state of pk. + var ( + link = test.data[:test.link] + network = test.data[test.link:][:test.network] + transport = test.data[test.link+test.network:][:test.transport] + payload = test.data[allHdrSize:] + ) + checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), payload) + checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) + // Check the after values for each header. + checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) + checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network) + checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport) + // Check the after values for PayloadSince. + checkViewEqual(t, "After PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(link, network, transport, payload)) + checkViewEqual(t, "After PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(network, transport, payload)) + checkViewEqual(t, "After PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(transport, payload)) + }) + } +} + +func TestPacketHeaderConsumeDataTooShort(t *testing.T) { + data := makeView(10) + + pk := NewPacketBuffer(PacketBufferOptions{ + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(data).ToVectorisedView(), + }) + + // Consume should fail if pkt.Data is too short. + if _, ok := pk.LinkHeader().Consume(11); ok { + t.Fatalf("pk.LinkHeader().Consume() = _, true; want _, false") + } + if _, ok := pk.NetworkHeader().Consume(11); ok { + t.Fatalf("pk.NetworkHeader().Consume() = _, true; want _, false") + } + if _, ok := pk.TransportHeader().Consume(11); ok { + t.Fatalf("pk.TransportHeader().Consume() = _, true; want _, false") + } + + // Check packet should look the same as initial packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + Data: buffer.View(data).ToVectorisedView(), + }) +} + +func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: headerSize * int(numHeaderType), + }) + + for _, h := range []PacketHeader{ + pk.TransportHeader(), + pk.NetworkHeader(), + pk.LinkHeader(), + } { + t.Run("PushedTwice/"+h.typ.String(), func(t *testing.T) { + h.Push(headerSize) + + defer func() { recover() }() + h.Push(headerSize) + t.Fatal("Second push should have panicked") + }) + } +} + +func TestPacketHeaderConsumeCalledAtMostOnce(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), + }) + + for _, h := range []PacketHeader{ + pk.LinkHeader(), + pk.NetworkHeader(), + pk.TransportHeader(), + } { + t.Run("ConsumedTwice/"+h.typ.String(), func(t *testing.T) { + if _, ok := h.Consume(headerSize); !ok { + t.Fatal("First consume should succeed") + } + + defer func() { recover() }() + h.Consume(headerSize) + t.Fatal("Second consume should have panicked") + }) + } +} + +func TestPacketHeaderPushThenConsumePanics(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: headerSize * int(numHeaderType), + }) + + for _, h := range []PacketHeader{ + pk.TransportHeader(), + pk.NetworkHeader(), + pk.LinkHeader(), + } { + t.Run(h.typ.String(), func(t *testing.T) { + h.Push(headerSize) + + defer func() { recover() }() + h.Consume(headerSize) + t.Fatal("Consume should have panicked") + }) + } +} + +func TestPacketHeaderConsumeThenPushPanics(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), + }) + + for _, h := range []PacketHeader{ + pk.LinkHeader(), + pk.NetworkHeader(), + pk.TransportHeader(), + } { + t.Run(h.typ.String(), func(t *testing.T) { + h.Consume(headerSize) + + defer func() { recover() }() + h.Push(headerSize) + t.Fatal("Push should have panicked") + }) + } +} + +func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { + t.Helper() + reserved := opts.ReserveHeaderBytes + if got, want := pk.ReservedHeaderBytes(), reserved; got != want { + t.Errorf("Initial pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), reserved; got != want { + t.Errorf("Initial pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), 0; got != want { + t.Errorf("Initial pk.HeaderSize() = %d, want %d", got, want) + } + data := opts.Data.ToView() + if got, want := pk.Size(), len(data); got != want { + t.Errorf("Initial pk.Size() = %d, want %d", got, want) + } + checkViewEqual(t, "Initial pk.Data.Views()", concatViews(pk.Data.Views()...), data) + checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) + // Check the initial values for each header. + checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) + checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil) + checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil) + // Check the initial valies for PayloadSince. + checkViewEqual(t, "Initial PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), data) + checkViewEqual(t, "Initial PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), data) + checkViewEqual(t, "Initial PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), data) +} + +func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) { + t.Helper() + checkViewEqual(t, name+".View()", h.View(), want) +} + +func checkViewEqual(t *testing.T, what string, got, want buffer.View) { + t.Helper() + if !bytes.Equal(got, want) { + t.Errorf("%s = %x, want %x", what, got, want) + } +} + +func makeView(size int) buffer.View { + b := byte(size) + return bytes.Repeat([]byte{b}, size) +} + +func concatViews(views ...buffer.View) buffer.View { + var all buffer.View + for _, v := range views { + all = append(all, v...) + } + return all +} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 8604c4259..4570e8969 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -249,8 +249,8 @@ type NetworkEndpoint interface { MaxHeaderLength() uint16 // WritePacket writes a packet to the given destination address and - // protocol. It takes ownership of pkt. pkt.TransportHeader must have already - // been set. + // protocol. It takes ownership of pkt. pkt.TransportHeader must have + // already been set. WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 9ce0a2c22..e267bebb0 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -173,7 +173,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf } // WritePacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + numBytes := pkt.Size() err := r.ref.ep.WritePacket(r, gso, params, pkt) if err != nil { @@ -203,8 +203,7 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead writtenBytes := 0 for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { - writtenBytes += pb.Header.UsedLength() - writtenBytes += pb.Data.Size() + writtenBytes += pb.Size() } r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index fe1c1b8a4..0273b3c63 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -102,7 +102,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Handle control packets. - if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) { + if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) { nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) if !ok { return @@ -118,7 +118,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -143,10 +143,10 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params // Add the protocol's header to the packet and send it to the link // endpoint. - pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen) - pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0] - pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0] - pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol) + hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen) + hdr[dstAddrOffset] = r.RemoteAddress[0] + hdr[srcAddrOffset] = f.id.LocalAddress[0] + hdr[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { f.HandlePacket(r, pkt) @@ -249,12 +249,10 @@ func (*fakeNetworkProtocol) Wait() {} // Parse implements TransportProtocol.Parse. func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen) + hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen) if !ok { return 0, false, false } - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(fakeNetHeaderLen) return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } @@ -315,9 +313,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[dstAddrOffset] = 3 - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -327,9 +325,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[dstAddrOffset] = 1 - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -339,9 +337,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[dstAddrOffset] = 2 - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -350,9 +348,9 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber-1, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -362,9 +360,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -383,11 +381,10 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro } func send(r stack.Route, payload buffer.View) *tcpip.Error { - hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + })) } func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { @@ -442,9 +439,9 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -2285,9 +2282,9 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -2367,9 +2364,9 @@ func TestNICForwarding(t *testing.T) { // Send a packet to dstAddr. buf := buffer.NewView(30) buf[dstAddrOffset] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) pkt, ok := ep2.Read() if !ok { @@ -2377,8 +2374,8 @@ func TestNICForwarding(t *testing.T) { } // Test that the link's MaxHeaderLength is honoured. - if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want { - t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want) + if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want { + t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want) } // Test that forwarding increments Tx stats correctly. diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 73dada928..1339edc2d 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -128,11 +128,10 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), }) + c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt) } func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { @@ -166,11 +165,10 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), }) + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt) } func TestTransportDemuxerRegister(t *testing.T) { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 7e8b84867..6c6e44468 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -84,16 +84,16 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return 0, nil, tcpip.ErrNoRoute } - hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()) + fakeTransHeaderLen) - hdr.Prepend(fakeTransHeaderLen) v, err := p.FullPayload() if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.View(v).ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen, + Data: buffer.View(v).ToVectorisedView(), + }) + _ = pkt.TransportHeader().Push(fakeTransHeaderLen) + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { return 0, nil, err } @@ -328,13 +328,8 @@ func (*fakeTransportProtocol) Wait() {} // Parse implements TransportProtocol.Parse. func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { - hdr, ok := pkt.Data.PullUp(fakeTransHeaderLen) - if !ok { - return false - } - pkt.TransportHeader = hdr - pkt.Data.TrimFront(fakeTransHeaderLen) - return true + _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen) + return ok } func fakeTransFactory() stack.TransportProtocol { @@ -382,9 +377,9 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -393,9 +388,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -404,9 +399,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.packetCount != 1 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) } @@ -459,9 +454,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -470,9 +465,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -481,9 +476,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.controlCount != 1 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) } @@ -636,9 +631,9 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: req.ToVectorisedView(), - }) + })) aep, _, err := ep.Accept() if err != nil || aep == nil { @@ -655,10 +650,11 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.NetworkHeader[0]; dst != 3 { + nh := stack.PayloadSince(p.Pkt.NetworkHeader()) + if dst := nh[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.NetworkHeader[1]; src != 1 { + if src := nh[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 091bc5281..07c85ce59 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -958,6 +958,26 @@ type SocketDetachFilterOption int // and port of a redirected packet. type OriginalDestinationOption FullAddress +// TCPTimeWaitReuseOption is used stack.(*Stack).TransportProtocolOption to +// specify if the stack can reuse the port bound by an endpoint in TIME-WAIT for +// new connections when it is safe from protocol viewpoint. +type TCPTimeWaitReuseOption uint8 + +const ( + // TCPTimeWaitReuseDisabled indicates reuse of port bound by endponts in TIME-WAIT cannot + // be reused for new connections. + TCPTimeWaitReuseDisabled TCPTimeWaitReuseOption = iota + + // TCPTimeWaitReuseGlobal indicates reuse of port bound by endponts in TIME-WAIT can + // be reused for new connections irrespective of the src/dest addresses. + TCPTimeWaitReuseGlobal + + // TCPTimeWaitReuseLoopbackOnly indicates reuse of port bound by endpoint in TIME-WAIT can + // only be reused if the connection was a connection over loopback. i.e src/dest adddresses + // are loopback addresses. + TCPTimeWaitReuseLoopbackOnly +) + // IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 0ff3a2b89..9f0dd4d6d 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -80,9 +80,9 @@ func TestPingMulticastBroadcast(t *testing.T) { DstAddr: dst, }) - e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) } rxIPv6ICMP := func(e *channel.Endpoint, dst tcpip.Address) { @@ -102,9 +102,9 @@ func TestPingMulticastBroadcast(t *testing.T) { DstAddr: dst, }) - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) } tests := []struct { @@ -204,7 +204,7 @@ func TestPingMulticastBroadcast(t *testing.T) { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) } - src, dst := proto.ParseAddresses(pkt.Pkt.NetworkHeader) + src, dst := proto.ParseAddresses(pkt.Pkt.NetworkHeader().View()) if src != expectedSrc { t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) } @@ -252,9 +252,9 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { DstAddr: dst, }) - e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) } rxIPv6UDP := func(e *channel.Endpoint, dst tcpip.Address) { @@ -280,9 +280,9 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { DstAddr: dst, }) - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) } tests := []struct { diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 4612be4e7..bd6f49eb8 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -430,9 +430,12 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi return tcpip.ErrInvalidEndpointState } - hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()), + }) + pkt.Owner = owner - icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize)) copy(icmpv4, data) // Set the ident to the user-specified port. Sequence number should // already be set by the user. @@ -447,15 +450,12 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) + pkt.Data = data.ToVectorisedView() + if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: data.ToVectorisedView(), - TransportHeader: buffer.View(icmpv4), - Owner: owner, - }) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) } func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { @@ -463,9 +463,11 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err return tcpip.ErrInvalidEndpointState } - hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()), + }) - icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize)) copy(icmpv6, data) // Set the ident. Sequence number is provided by the user. icmpv6.SetIdent(ident) @@ -477,15 +479,12 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err dataVV := data.ToVectorisedView() icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV)) + pkt.Data = dataVV if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: dataVV, - TransportHeader: buffer.View(icmpv6), - }) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) } // checkV4MappedLocked determines the effective network protocol and converts @@ -748,14 +747,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(pkt.TransportHeader) + h := header.ICMPv4(pkt.TransportHeader().View()) + // TODO(b/129292233): Determine if len(h) check is still needed after early + // parsing. if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(pkt.TransportHeader) + h := header.ICMPv6(pkt.TransportHeader().View()) + // TODO(b/129292233): Determine if len(h) check is still needed after early + // parsing. if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -791,7 +794,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } // ICMP socket's data includes ICMP header. - packet.data = pkt.TransportHeader.ToVectorisedView() + packet.data = pkt.TransportHeader().View().ToVectorisedView() packet.data.Append(pkt.Data) e.rcvList.PushBack(packet) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index df478115d..1b03ad6bb 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -433,9 +433,9 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, // Push new packet into receive list and increment the buffer size. var packet packet // TODO(gvisor.dev/issue/173): Return network protocol. - if len(pkt.LinkHeader) > 0 { + if !pkt.LinkHeader().View().IsEmpty() { // Get info directly from the ethernet header. - hdr := header.Ethernet(pkt.LinkHeader) + hdr := header.Ethernet(pkt.LinkHeader().View()) packet.senderAddr = tcpip.FullAddress{ NIC: nicID, Addr: tcpip.Address(hdr.SourceAddress()), @@ -458,9 +458,14 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, case tcpip.PacketHost: packet.data = pkt.Data case tcpip.PacketOutgoing: - // Strip Link Header from the Header. - pkt.Header = buffer.NewPrependableFromView(pkt.Header.View()[len(pkt.LinkHeader):]) - combinedVV := pkt.Header.View().ToVectorisedView() + // Strip Link Header. + var combinedVV buffer.VectorisedView + if v := pkt.NetworkHeader().View(); !v.IsEmpty() { + combinedVV.AppendView(v) + } + if v := pkt.TransportHeader().View(); !v.IsEmpty() { + combinedVV.AppendView(v) + } combinedVV.Append(pkt.Data) packet.data = combinedVV default: @@ -471,9 +476,8 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, // Raw packets need their ethernet headers prepended before // queueing. var linkHeader buffer.View - var combinedVV buffer.VectorisedView if pkt.PktType != tcpip.PacketOutgoing { - if len(pkt.LinkHeader) == 0 { + if pkt.LinkHeader().View().IsEmpty() { // We weren't provided with an actual ethernet header, // so fake one. ethFields := header.EthernetFields{ @@ -485,19 +489,14 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, fakeHeader.Encode(ðFields) linkHeader = buffer.View(fakeHeader) } else { - linkHeader = append(buffer.View(nil), pkt.LinkHeader...) + linkHeader = append(buffer.View(nil), pkt.LinkHeader().View()...) } - combinedVV = linkHeader.ToVectorisedView() - } - if pkt.PktType == tcpip.PacketOutgoing { - // For outgoing packets the Link, Network and Transport - // headers are in the pkt.Header fields normally unless - // a Raw socket is in use in which case pkt.Header could - // be nil. - combinedVV.AppendView(pkt.Header.View()) + combinedVV := linkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + } else { + packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views()) } - combinedVV.Append(pkt.Data) - packet.data = combinedVV } packet.timestampNS = ep.stack.Clock().NowNanoseconds() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index f85a68554..edc2b5b61 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -352,18 +352,23 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } if e.hdrIncluded { - if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(payloadBytes).ToVectorisedView(), - }); err != nil { + }) + if err := route.WriteHeaderIncludedPacket(pkt); err != nil { return 0, nil, err } } else { - hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.View(payloadBytes).ToVectorisedView(), - Owner: e.owner, - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(route.MaxHeaderLength()), + Data: buffer.View(payloadBytes).ToVectorisedView(), + }) + pkt.Owner = e.owner + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: e.TransProto, + TTL: route.DefaultTTL(), + TOS: stack.DefaultTOS, + }, pkt); err != nil { return 0, nil, err } } @@ -691,12 +696,13 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { // slice. Save/restore doesn't support overlapping slices and will fail. var combinedVV buffer.VectorisedView if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber { - headers := make(buffer.View, 0, len(pkt.NetworkHeader)+len(pkt.TransportHeader)) - headers = append(headers, pkt.NetworkHeader...) - headers = append(headers, pkt.TransportHeader...) + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + headers := make(buffer.View, 0, len(network)+len(transport)) + headers = append(headers, network...) + headers = append(headers, transport...) combinedVV = headers.ToVectorisedView() } else { - combinedVV = append(buffer.View(nil), pkt.TransportHeader...).ToVectorisedView() + combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView() } combinedVV.Append(pkt.Data) packet.data = combinedVV diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 913ea6535..b706438bd 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -212,7 +212,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.route = s.route.Clone() n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} n.rcvBufSize = int(l.rcvWnd) - n.amss = mssForRoute(&n.route) + n.amss = calculateAdvertisedMSS(n.userMSS, n.route) n.setEndpointState(StateConnecting) n.maybeEnableTimestamp(rcvdSynOpts) @@ -380,6 +380,7 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { n.portFlags = e.portFlags n.boundBindToDevice = e.boundBindToDevice n.boundPortFlags = e.boundPortFlags + n.userMSS = e.userMSS } // reserveTupleLocked reserves an accepted endpoint's tuple. @@ -481,9 +482,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { return } - // TODO(b/143300739): Use the userMSS of the listening socket - // for accepted sockets. - switch { case s.flags == header.TCPFlagSyn: opts := parseSynSegmentOptions(s) @@ -514,16 +512,19 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) // Send SYN without window scaling because we currently - // dont't encode this information in the cookie. + // don't encode this information in the cookie. // // Enable Timestamp option if the original syn did have // the timestamp option specified. + // + // Use the user supplied MSS on the listening socket for + // new connections, if available. synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), TSEcr: opts.TSVal, - MSS: mssForRoute(&s.route), + MSS: calculateAdvertisedMSS(e.userMSS, s.route), } e.sendSynTCP(&s.route, tcpFields{ id: s.id, diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 8dd759ba2..290172ac9 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -746,11 +746,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { optLen := len(tf.opts) - hdr := &pkt.Header - packetSize := pkt.Data.Size() - // Initialize the header. - tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) - pkt.TransportHeader = buffer.View(tcp) + tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen)) tcp.Encode(&header.TCPFields{ SrcPort: tf.id.LocalPort, DstPort: tf.id.RemotePort, @@ -762,8 +758,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta }) copy(tcp[header.TCPMinimumSize:], tf.opts) - length := uint16(hdr.UsedLength() + packetSize) - xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) + xsum := r.PseudoHeaderChecksum(ProtocolNumber, uint16(pkt.Size())) // Only calculate the checksum if offloading isn't supported. if gso != nil && gso.NeedsCsum { // This is called CHECKSUM_PARTIAL in the Linux kernel. We @@ -801,17 +796,18 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso packetSize = size } size -= packetSize - var pkt stack.PacketBuffer - pkt.Header = buffer.NewPrependable(hdrSize) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: hdrSize, + }) pkt.Hash = tf.txHash pkt.Owner = owner pkt.EgressRoute = r pkt.GSOOptions = gso pkt.NetworkProtocolNumber = r.NetworkProtocolNumber() data.ReadToVV(&pkt.Data, packetSize) - buildTCPHdr(r, tf, &pkt, gso) + buildTCPHdr(r, tf, pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) - pkts.PushBack(&pkt) + pkts.PushBack(pkt) } if tf.ttl == 0 { @@ -837,12 +833,12 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac return sendTCPBatch(r, tf, data, gso, owner) } - pkt := &stack.PacketBuffer{ - Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), - Data: data, - Hash: tf.txHash, - Owner: owner, - } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen, + Data: data, + }) + pkt.Hash = tf.txHash + pkt.Owner = owner buildTCPHdr(r, tf, pkt, gso) if tf.ttl == 0 { @@ -1706,7 +1702,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { } case notification: n := e.fetchNotifications() - if n¬ifyClose != 0 || n¬ifyAbort != 0 { + if n¬ifyAbort != 0 { return nil } if n¬ifyDrain != 0 { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index b8b52b03d..1ccedebcc 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -449,10 +449,11 @@ type endpoint struct { // recentTS is the timestamp that should be sent in the TSEcr field of // the timestamp for future segments sent by the endpoint. This field is // updated if required when a new segment is received by this endpoint. - // - // recentTS must be read/written atomically. recentTS uint32 + // recentTSTime is the unix time when we updated recentTS last. + recentTSTime time.Time `state:".(unixTime)"` + // tsOffset is a randomized offset added to the value of the // TSVal field in the timestamp option. tsOffset uint32 @@ -666,7 +667,8 @@ func (e *endpoint) UniqueID() uint64 { // r, it will be used; otherwise, the maximum possible MSS will be used. func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { // The maximum possible MSS is dependent on the route. - maxMSS := mssForRoute(&r) + // TODO(b/143359391): Respect TCP Min and Max size. + maxMSS := uint16(r.MTU() - header.TCPMinimumSize) if userMSS != 0 && userMSS < maxMSS { return userMSS @@ -795,15 +797,15 @@ func (e *endpoint) EndpointState() EndpointState { return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) } -// setRecentTimestamp atomically sets the recentTS field to the -// provided value. +// setRecentTimestamp sets the recentTS field to the provided value. func (e *endpoint) setRecentTimestamp(recentTS uint32) { - atomic.StoreUint32(&e.recentTS, recentTS) + e.recentTS = recentTS + e.recentTSTime = time.Now() } -// recentTimestamp atomically reads and returns the value of the recentTS field. +// recentTimestamp returns the value of the recentTS field. func (e *endpoint) recentTimestamp() uint32 { - return atomic.LoadUint32(&e.recentTS) + return e.recentTS } // keepalive is a synchronization wrapper used to appease stateify. See the @@ -902,7 +904,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. - case StateClose, StateError: + case StateClose, StateError, StateTimeWait: // Ready for anything. result = mask @@ -2148,12 +2150,66 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc h.Write(portBuf) portOffset := h.Sum32() + var twReuse tcpip.TCPTimeWaitReuseOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &twReuse, err)) + } + + reuse := twReuse == tcpip.TCPTimeWaitReuseGlobal + if twReuse == tcpip.TCPTimeWaitReuseLoopbackOnly { + switch netProto { + case header.IPv4ProtocolNumber: + reuse = header.IsV4LoopbackAddress(e.ID.LocalAddress) && header.IsV4LoopbackAddress(e.ID.RemoteAddress) + case header.IPv6ProtocolNumber: + reuse = e.ID.LocalAddress == header.IPv6Loopback && e.ID.RemoteAddress == header.IPv6Loopback + } + } + if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { if sameAddr && p == e.ID.RemotePort { return false, nil } if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil { - return false, nil + if err != tcpip.ErrPortInUse || !reuse { + return false, nil + } + transEPID := e.ID + transEPID.LocalPort = p + // Check if an endpoint is registered with demuxer in TIME-WAIT and if + // we can reuse it. If we can't find a transport endpoint then we just + // skip using this port as it's possible that either an endpoint has + // bound the port but not registered with demuxer yet (no listen/connect + // done yet) or the reservation was freed between the check above and + // the FindTransportEndpoint below. But rather than retry the same port + // we just skip it and move on. + transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r) + if transEP == nil { + // ReservePort failed but there is no registered endpoint with + // demuxer. Which indicates there is at least some endpoint that has + // bound the port. + return false, nil + } + + tcpEP := transEP.(*endpoint) + tcpEP.LockUser() + // If the endpoint is not in TIME-WAIT or if it is in TIME-WAIT but + // less than 1 second has elapsed since its recentTS was updated then + // we cannot reuse the port. + if tcpEP.EndpointState() != StateTimeWait || time.Since(tcpEP.recentTSTime) < 1*time.Second { + tcpEP.UnlockUser() + return false, nil + } + // Since the endpoint is in TIME-WAIT it should be safe to acquire its + // Lock while holding the lock for this endpoint as endpoints in + // TIME-WAIT do not acquire locks on other endpoints. + tcpEP.workerCleanup = false + tcpEP.cleanupLocked() + tcpEP.notifyProtocolGoroutine(notifyAbort) + tcpEP.UnlockUser() + // Now try and Reserve again if it fails then we skip. + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil { + return false, nil + } } id := e.ID @@ -2911,8 +2967,3 @@ func (e *endpoint) Wait() { <-notifyCh } } - -func mssForRoute(r *stack.Route) uint16 { - // TODO(b/143359391): Respect TCP Min and Max size. - return uint16(r.MTU() - header.TCPMinimumSize) -} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index abf1ac5c9..723e47ddc 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -309,6 +309,16 @@ func (e *endpoint) loadLastError(s string) { e.lastError = tcpip.StringToError(s) } +// saveRecentTSTime is invoked by stateify. +func (e *endpoint) saveRecentTSTime() unixTime { + return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()} +} + +// loadRecentTSTime is invoked by stateify. +func (e *endpoint) loadRecentTSTime(unix unixTime) { + e.recentTSTime = time.Unix(unix.second, unix.nano) +} + // saveHardError is invoked by stateify. func (e *EndpointInfo) saveHardError() string { if e.HardError == nil { diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 2e5093b36..c5afa2680 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -21,7 +21,6 @@ package tcp import ( - "fmt" "runtime" "strings" "time" @@ -191,8 +190,9 @@ type protocol struct { congestionControl string availableCongestionControl []string moderateReceiveBuffer bool - tcpLingerTimeout time.Duration - tcpTimeWaitTimeout time.Duration + lingerTimeout time.Duration + timeWaitTimeout time.Duration + timeWaitReuse tcpip.TCPTimeWaitReuseOption minRTO time.Duration maxRTO time.Duration maxRetries uint32 @@ -358,7 +358,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { v = 0 } p.mu.Lock() - p.tcpLingerTimeout = time.Duration(v) + p.lingerTimeout = time.Duration(v) p.mu.Unlock() return nil @@ -367,7 +367,16 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { v = 0 } p.mu.Lock() - p.tcpTimeWaitTimeout = time.Duration(v) + p.timeWaitTimeout = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPTimeWaitReuseOption: + if v < tcpip.TCPTimeWaitReuseDisabled || v > tcpip.TCPTimeWaitReuseLoopbackOnly { + return tcpip.ErrInvalidOptionValue + } + p.mu.Lock() + p.timeWaitReuse = v p.mu.Unlock() return nil @@ -468,13 +477,19 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { case *tcpip.TCPLingerTimeoutOption: p.mu.RLock() - *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout) + *v = tcpip.TCPLingerTimeoutOption(p.lingerTimeout) p.mu.RUnlock() return nil case *tcpip.TCPTimeWaitTimeoutOption: p.mu.RLock() - *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout) + *v = tcpip.TCPTimeWaitTimeoutOption(p.timeWaitTimeout) + p.mu.RUnlock() + return nil + + case *tcpip.TCPTimeWaitReuseOption: + p.mu.RLock() + *v = tcpip.TCPTimeWaitReuseOption(p.timeWaitReuse) p.mu.RUnlock() return nil @@ -531,22 +546,22 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter { // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + // TCP header is variable length, peek at it first. + hdrLen := header.TCPMinimumSize + hdr, ok := pkt.Data.PullUp(hdrLen) if !ok { return false } // If the header has options, pull those up as well. if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { - hdr, ok = pkt.Data.PullUp(offset) - if !ok { - panic(fmt.Sprintf("There should be at least %d bytes in pkt.Data.", offset)) - } + // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of + // packets. + hdrLen = offset } - pkt.TransportHeader = hdr - pkt.Data.TrimFront(len(hdr)) - return true + _, ok = pkt.TransportHeader().Consume(hdrLen) + return ok } // NewProtocol returns a TCP transport protocol. @@ -564,8 +579,9 @@ func NewProtocol() stack.TransportProtocol { }, congestionControl: ccReno, availableCongestionControl: []string{ccReno, ccCubic}, - tcpLingerTimeout: DefaultTCPLingerTimeout, - tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, + lingerTimeout: DefaultTCPLingerTimeout, + timeWaitTimeout: DefaultTCPTimeWaitTimeout, + timeWaitReuse: tcpip.TCPTimeWaitReuseLoopbackOnly, synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold}, synRetries: DefaultSynRetries, minRTO: MinRTO, diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index bb60dc29d..94307d31a 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -68,7 +68,7 @@ func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketB route: r.Clone(), } s.data = pkt.Data.Clone(s.views[:]) - s.hdr = header.TCP(pkt.TransportHeader) + s.hdr = header.TCP(pkt.TransportHeader().View()) s.rcvdTime = time.Now() return s } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 1b58eb91b..55ae09a2f 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -750,128 +750,227 @@ func TestSimpleReceive(t *testing.T) { ) } -// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when -// creating a new active IPv4 TCP socket. It should be present in the sent TCP +// TestUserSuppliedMSSOnConnect tests that the user supplied MSS is used when +// creating a new active TCP socket. It should be present in the sent TCP // SYN segment. -func TestUserSuppliedMSSOnConnectV4(t *testing.T) { +func TestUserSuppliedMSSOnConnect(t *testing.T) { const mtu = 5000 - const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize - tests := []struct { - name string - setMSS int - expMSS uint16 + + ips := []struct { + name string + createEP func(*context.Context) + connectAddr tcpip.Address + checker func(*testing.T, *context.Context, uint16, int) + maxMSS uint16 }{ { - "EqualToMaxMSS", - maxMSS, - maxMSS, - }, - { - "LessThanMTU", - maxMSS - 1, - maxMSS - 1, + name: "IPv4", + createEP: func(c *context.Context) { + c.Create(-1) + }, + connectAddr: context.TestAddr, + checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) + }, + maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, }, { - "GreaterThanMTU", - maxMSS + 1, - maxMSS, + name: "IPv6", + createEP: func(c *context.Context) { + c.CreateV6Endpoint(true) + }, + connectAddr: context.TestV6Addr, + checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) + }, + maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() + for _, ip := range ips { + t.Run(ip.name, func(t *testing.T) { + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + name: "EqualToMaxMSS", + setMSS: ip.maxMSS, + expMSS: ip.maxMSS, + }, + { + name: "LessThanMaxMSS", + setMSS: ip.maxMSS - 1, + expMSS: ip.maxMSS - 1, + }, + { + name: "GreaterThanMaxMSS", + setMSS: ip.maxMSS + 1, + expMSS: ip.maxMSS, + }, + } - c.Create(-1) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() - // Set the MSS socket option. - if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, test.setMSS); err != nil { - t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err) - } + ip.createEP(c) - // Get expected window size. - rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) - } - ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + // Set the MSS socket option. + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) + } - // Start connection attempt to IPv4 address. - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("unexpected return value from Connect: %s", err) - } + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption): %s", err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) - // Receive SYN packet with our user supplied MSS. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} + if err := c.EP.Connect(connectAddr); err != tcpip.ErrConnectStarted { + t.Fatalf("Connect(%+v): %s", connectAddr, err) + } + + // Receive SYN packet with our user supplied MSS. + ip.checker(t, c, test.expMSS, ws) + }) + } }) } } -// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when -// creating a new active IPv6 TCP socket. It should be present in the sent TCP -// SYN segment. -func TestUserSuppliedMSSOnConnectV6(t *testing.T) { - const mtu = 5000 - const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize - tests := []struct { - name string - setMSS uint16 - expMSS uint16 +// TestUserSuppliedMSSOnListenAccept tests that the user supplied MSS is used +// when completing the handshake for a new TCP connection from a TCP +// listening socket. It should be present in the sent TCP SYN-ACK segment. +func TestUserSuppliedMSSOnListenAccept(t *testing.T) { + const ( + nonSynCookieAccepts = 2 + totalAccepts = 4 + mtu = 5000 + ) + + ips := []struct { + name string + createEP func(*context.Context) + sendPkt func(*context.Context, *context.Headers) + checker func(*testing.T, *context.Context, uint16, uint16) + maxMSS uint16 }{ { - "EqualToMaxMSS", - maxMSS, - maxMSS, - }, - { - "LessThanMTU", - maxMSS - 1, - maxMSS - 1, + name: "IPv4", + createEP: func(c *context.Context) { + c.Create(-1) + }, + sendPkt: func(c *context.Context, h *context.Headers) { + c.SendPacket(nil, h) + }, + checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(srcPort), + checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) + }, + maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, }, { - "GreaterThanMTU", - maxMSS + 1, - maxMSS, + name: "IPv6", + createEP: func(c *context.Context) { + c.CreateV6Endpoint(false) + }, + sendPkt: func(c *context.Context, h *context.Headers) { + c.SendV6Packet(nil, h) + }, + checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(srcPort), + checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) + }, + maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() + for _, ip := range ips { + t.Run(ip.name, func(t *testing.T) { + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + name: "EqualToMaxMSS", + setMSS: ip.maxMSS, + expMSS: ip.maxMSS, + }, + { + name: "LessThanMaxMSS", + setMSS: ip.maxMSS - 1, + expMSS: ip.maxMSS - 1, + }, + { + name: "GreaterThanMaxMSS", + setMSS: ip.maxMSS + 1, + expMSS: ip.maxMSS, + }, + } - c.CreateV6Endpoint(true) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() - // Set the MSS socket option. - if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { - t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err) - } + ip.createEP(c) - // Get expected window size. - rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) - } - ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + // Set the SynRcvd threshold to force a syn cookie based accept to happen. + opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, %#v): %s", tcp.ProtocolNumber, opt, err) + } - // Start connection attempt to IPv6 address. - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("unexpected return value from Connect: %s", err) - } + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) + } - // Receive SYN packet with our user supplied MSS. - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + bindAddr := tcpip.FullAddress{Port: context.StackPort} + if err := c.EP.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s:", bindAddr, err) + } + + if err := c.EP.Listen(totalAccepts); err != nil { + t.Fatalf("Listen(%d): %s:", totalAccepts, err) + } + + // The first nonSynCookieAccepts packets sent will trigger a gorooutine + // based accept. The rest will trigger a cookie based accept. + for i := 0; i < totalAccepts; i++ { + // Send a SYN requests. + iss := seqnum.Value(i) + srcPort := context.TestPort + uint16(i) + ip.sendPkt(c, &context.Headers{ + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + ip.checker(t, c, srcPort, test.expMSS) + } + }) + } }) } } - func TestSendRstOnListenerRxSynAckV4(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -7297,3 +7396,53 @@ func TestResetDuringClose(t *testing.T) { wg.Wait() } + +func TestStackTimeWaitReuse(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + s := c.Stack() + var twReuse tcpip.TCPTimeWaitReuseOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err) + } + if got, want := twReuse, tcpip.TCPTimeWaitReuseLoopbackOnly; got != want { + t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want) + } +} + +func TestSetStackTimeWaitReuse(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + s := c.Stack() + testCases := []struct { + v int + err *tcpip.Error + }{ + {int(tcpip.TCPTimeWaitReuseDisabled), nil}, + {int(tcpip.TCPTimeWaitReuseGlobal), nil}, + {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil}, + {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, tcpip.ErrInvalidOptionValue}, + {int(tcpip.TCPTimeWaitReuseDisabled) - 1, tcpip.ErrInvalidOptionValue}, + } + + for _, tc := range testCases { + err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitReuseOption(tc.v)) + if got, want := err, tc.err; got != want { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.v, err, tc.err) + } + if tc.err != nil { + continue + } + + var twReuse tcpip.TCPTimeWaitReuseOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want nil", tcp.ProtocolNumber, &twReuse, err) + } + + if got, want := twReuse, tcpip.TCPTimeWaitReuseOption(tc.v); got != want { + t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want) + } + } +} diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 37e7767d6..b6031354e 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -257,8 +257,8 @@ func (c *Context) GetPacket() []byte { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) @@ -284,15 +284,15 @@ func (c *Context) GetPacketNonBlocking() []byte { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) return b } // SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. -func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) { +func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code, p1, p2 []byte, maxTotalSize int) { // Allocate a buffer data and headers. buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2)) if len(buf) > maxTotalSize { @@ -318,9 +318,10 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // BuildSegment builds a TCP segment based on the given Headers and payload. @@ -374,26 +375,29 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: s, }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: c.BuildSegment(payload, h), }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // SendPacketWithAddrs builds and sends a TCP segment(with the provided payload // & TCPheaders) in an IPv4 packet via the link layer endpoint using the // provided source and destination IPv4 addresses. func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: c.BuildSegmentWithAddrs(payload, h, src, dst), }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // SendAck sends an ACK packet. @@ -514,9 +518,8 @@ func (c *Context) GetV6Packet() []byte { if p.Proto != ipv6.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) } - b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) - copy(b, p.Pkt.Header.View()) - copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) return b @@ -566,9 +569,10 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), }) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, pkt) } // CreateConnected creates a connected TCP endpoint. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 4a2b6c03a..73608783c 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -986,13 +986,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error { - // Allocate a buffer for the UDP header. - hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), + Data: data, + }) + pkt.Owner = owner - // Initialize the header. - udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + // Initialize the UDP header. + udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) - length := uint16(hdr.UsedLength() + data.Size()) + length := uint16(pkt.Size()) udp.Encode(&header.UDPFields{ SrcPort: localPort, DstPort: remotePort, @@ -1019,12 +1022,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u Protocol: ProtocolNumber, TTL: ttl, TOS: tos, - }, &stack.PacketBuffer{ - Header: hdr, - Data: data, - TransportHeader: buffer.View(udp), - Owner: owner, - }); err != nil { + }, pkt); err != nil { r.Stats().UDP.PacketSendErrors.Increment() return err } @@ -1372,7 +1370,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(pkt.TransportHeader) + hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() @@ -1443,9 +1441,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Save any useful information from the network header to the packet. switch r.NetProto { case header.IPv4ProtocolNumber: - packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS() + packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS() case header.IPv6ProtocolNumber: - packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS() + packet.tos, _ = header.IPv6(pkt.NetworkHeader().View()).TOS() } // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 0e7464e3a..63d4bed7c 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -82,7 +82,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - hdr := header.UDP(pkt.TransportHeader) + hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() @@ -130,7 +130,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen - payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() + payloadLen := pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size() + pkt.Data.Size() if payloadLen > available { payloadLen = available } @@ -139,22 +139,21 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans // For example, a raw or packet socket may use what UDP // considers an unreachable destination. Thus we deep copy pkt // to prevent multiple ownership and SR errors. - newHeader := append(buffer.View(nil), pkt.NetworkHeader...) - newHeader = append(newHeader, pkt.TransportHeader...) + newHeader := append(buffer.View(nil), pkt.NetworkHeader().View()...) + newHeader = append(newHeader, pkt.TransportHeader().View()...) payload := newHeader.ToVectorisedView() payload.AppendView(pkt.Data.ToView()) payload.CapLength(payloadLen) - hdr := buffer.NewPrependable(headerLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4DstUnreachable) - pkt.SetCode(header.ICMPv4PortUnreachable) - pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - TransportHeader: buffer.View(pkt), - Data: payload, + icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: headerLen, + Data: payload, }) + icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4PortUnreachable) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt) case header.IPv6AddressSize: if !r.Stack().AllowICMPMessage() { @@ -175,24 +174,24 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize available := int(mtu) - headerLen - payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + payloadLen := len(network) + len(transport) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(pkt.NetworkHeader)+len(pkt.TransportHeader), []buffer.View{pkt.NetworkHeader, pkt.TransportHeader}) + payload := buffer.NewVectorisedView(len(network)+len(transport), []buffer.View{network, transport}) payload.Append(pkt.Data) payload.CapLength(payloadLen) - hdr := buffer.NewPrependable(headerLen) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize)) - pkt.SetType(header.ICMPv6DstUnreachable) - pkt.SetCode(header.ICMPv6PortUnreachable) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - TransportHeader: buffer.View(pkt), - Data: payload, + icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: headerLen, + Data: payload, }) + icmpHdr := header.ICMPv6(icmpPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6PortUnreachable) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, icmpPkt.Data)) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt) } return true } @@ -215,14 +214,8 @@ func (*protocol) Wait() {} // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - h, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { - // Packet is too small - return false - } - pkt.TransportHeader = h - pkt.Data.TrimFront(header.UDPMinimumSize) - return true + _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize) + return ok } // NewProtocol returns a UDP transport protocol. diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 1a32622ca..71776d6db 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -388,8 +388,8 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() h := flow.header4Tuple(outgoing) checkers = append( @@ -410,14 +410,14 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { h := flow.header4Tuple(incoming) if flow.isV4() { buf := c.buildV4Packet(payload, &h) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } else { buf := c.buildV6Packet(payload, &h) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } } @@ -804,9 +804,9 @@ func TestV4ReadSelfSource(t *testing.T) { h.srcAddr = h.dstAddr buf := c.buildV4Packet(payload, &h) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) @@ -1766,9 +1766,8 @@ func TestV4UnknownDestination(t *testing.T) { return } - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + pkt := vv.ToView() if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } @@ -1844,9 +1843,8 @@ func TestV6UnknownDestination(t *testing.T) { return } - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + pkt := vv.ToView() if got, want := len(pkt), header.IPv6MinimumMTU; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } @@ -1897,9 +1895,9 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) { u := header.UDP(buf[header.IPv6MinimumSize:]) u.SetLength(u.Length() + 1) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { @@ -1952,9 +1950,9 @@ func TestShortHeader(t *testing.T) { copy(buf[header.IPv6MinimumSize:], udpHdr) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want { t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want) @@ -1986,9 +1984,9 @@ func TestIncrementChecksumErrorsV4(t *testing.T) { } } - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2019,9 +2017,9 @@ func TestIncrementChecksumErrorsV6(t *testing.T) { u := header.UDP(buf[header.IPv6MinimumSize:]) u.SetChecksum(u.Checksum() + 1) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2049,9 +2047,9 @@ func TestPayloadModifiedV4(t *testing.T) { buf := c.buildV4Packet(payload, &h) // Modify the payload so that the checksum value in the UDP header will be incorrect. buf[len(buf)-1]++ - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2079,9 +2077,9 @@ func TestPayloadModifiedV6(t *testing.T) { buf := c.buildV6Packet(payload, &h) // Modify the payload so that the checksum value in the UDP header will be incorrect. buf[len(buf)-1]++ - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2110,9 +2108,9 @@ func TestChecksumZeroV4(t *testing.T) { // Set the checksum field in the UDP header to zero. u := header.UDP(buf[header.IPv4MinimumSize:]) u.SetChecksum(0) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 0 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2141,9 +2139,9 @@ func TestChecksumZeroV6(t *testing.T) { // Set the checksum field in the UDP header to zero. u := header.UDP(buf[header.IPv6MinimumSize:]) u.SetChecksum(0) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go index 5a9dd8bd8..952871f95 100644 --- a/pkg/test/dockerutil/dockerutil.go +++ b/pkg/test/dockerutil/dockerutil.go @@ -88,44 +88,74 @@ func EnsureSupportedDockerVersion() { // RuntimePath returns the binary path for the current runtime. func RuntimePath() (string, error) { + rs, err := runtimeMap() + if err != nil { + return "", err + } + + p, ok := rs["path"].(string) + if !ok { + // The runtime does not declare a path. + return "", fmt.Errorf("runtime does not declare a path: %v", rs) + } + return p, nil +} + +// UsingVFS2 returns true if the 'runtime' has the vfs2 flag set. +// TODO(gvisor.dev/issue/1624): Remove. +func UsingVFS2() (bool, error) { + rMap, err := runtimeMap() + if err != nil { + return false, err + } + + list, ok := rMap["runtimeArgs"].([]interface{}) + if !ok { + return false, fmt.Errorf("unexpected format: %v", rMap) + } + + for _, element := range list { + if element == "--vfs2" { + return true, nil + } + } + return false, nil +} + +func runtimeMap() (map[string]interface{}, error) { // Read the configuration data; the file must exist. configBytes, err := ioutil.ReadFile(*config) if err != nil { - return "", err + return nil, err } // Unmarshal the configuration. c := make(map[string]interface{}) if err := json.Unmarshal(configBytes, &c); err != nil { - return "", err + return nil, err } // Decode the expected configuration. r, ok := c["runtimes"] if !ok { - return "", fmt.Errorf("no runtimes declared: %v", c) + return nil, fmt.Errorf("no runtimes declared: %v", c) } rs, ok := r.(map[string]interface{}) if !ok { // The runtimes are not a map. - return "", fmt.Errorf("unexpected format: %v", c) + return nil, fmt.Errorf("unexpected format: %v", rs) } r, ok = rs[*runtime] if !ok { // The expected runtime is not declared. - return "", fmt.Errorf("runtime %q not found: %v", *runtime, c) + return nil, fmt.Errorf("runtime %q not found: %v", *runtime, rs) } rs, ok = r.(map[string]interface{}) if !ok { // The runtime is not a map. - return "", fmt.Errorf("unexpected format: %v", c) - } - p, ok := rs["path"].(string) - if !ok { - // The runtime does not declare a path. - return "", fmt.Errorf("unexpected format: %v", c) + return nil, fmt.Errorf("unexpected format: %v", r) } - return p, nil + return rs, nil } // Save exports a container image to the given Writer. |