diff options
Diffstat (limited to 'pkg')
259 files changed, 15813 insertions, 3452 deletions
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index 1e23850a9..67646f837 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -242,7 +242,7 @@ const ( // Statx represents struct statx. // -// +marshal +// +marshal slice:StatxSlice type Statx struct { Mask uint32 Blksize uint32 @@ -270,6 +270,8 @@ type Statx struct { var SizeOfStatx = (*Statx)(nil).SizeBytes() // FileMode represents a mode_t. +// +// +marshal type FileMode uint16 // Permissions returns just the permission bits. diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 29062c97a..4526d3f95 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -180,8 +180,12 @@ var ( // to get quote. const SizeOfQuoteInputData = 64 -// SignReport is a struct that gets signed quote from input data. +// SignReport is a struct that gets signed quote from input data. The +// serialized quote is copied to buf. +// size is an input that specifies the size of buf. When returned, it's updated +// to the size of quote. type SignReport struct { - data [64]byte - quote []byte + data [64]byte + size uint32 + buf []byte } diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 95871b8a5..f60e42997 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -542,6 +542,15 @@ type ControlMessageIPPacketInfo struct { DestinationAddr InetAddr } +// ControlMessageIPv6PacketInfo represents struct in6_pktinfo from linux/ipv6.h. +// +// +marshal +// +stateify savable +type ControlMessageIPv6PacketInfo struct { + Addr Inet6Addr + NIC uint32 +} + // SizeOfControlMessageCredentials is the binary size of a // ControlMessageCredentials struct. var SizeOfControlMessageCredentials = (*ControlMessageCredentials)(nil).SizeBytes() @@ -566,6 +575,10 @@ const SizeOfControlMessageTClass = 4 // control message. const SizeOfControlMessageIPPacketInfo = 12 +// SizeOfControlMessageIPv6PacketInfo is the size of a +// ControlMessageIPv6PacketInfo. +const SizeOfControlMessageIPv6PacketInfo = 20 + // SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call. // From net/scm.h. const SCM_MAX_FD = 253 diff --git a/pkg/atomicbitops/atomicbitops_amd64.s b/pkg/atomicbitops/atomicbitops_amd64.s index 54c887ee5..6b9a67adc 100644 --- a/pkg/atomicbitops/atomicbitops_amd64.s +++ b/pkg/atomicbitops/atomicbitops_amd64.s @@ -16,28 +16,28 @@ #include "textflag.h" -TEXT ·AndUint32(SB),$0-12 - MOVQ addr+0(FP), BP +TEXT ·AndUint32(SB),NOSPLIT,$0-12 + MOVQ addr+0(FP), BX MOVL val+8(FP), AX LOCK - ANDL AX, 0(BP) + ANDL AX, 0(BX) RET -TEXT ·OrUint32(SB),$0-12 - MOVQ addr+0(FP), BP +TEXT ·OrUint32(SB),NOSPLIT,$0-12 + MOVQ addr+0(FP), BX MOVL val+8(FP), AX LOCK - ORL AX, 0(BP) + ORL AX, 0(BX) RET -TEXT ·XorUint32(SB),$0-12 - MOVQ addr+0(FP), BP +TEXT ·XorUint32(SB),NOSPLIT,$0-12 + MOVQ addr+0(FP), BX MOVL val+8(FP), AX LOCK - XORL AX, 0(BP) + XORL AX, 0(BX) RET -TEXT ·CompareAndSwapUint32(SB),$0-20 +TEXT ·CompareAndSwapUint32(SB),NOSPLIT,$0-20 MOVQ addr+0(FP), DI MOVL old+8(FP), AX MOVL new+12(FP), DX @@ -46,28 +46,28 @@ TEXT ·CompareAndSwapUint32(SB),$0-20 MOVL AX, ret+16(FP) RET -TEXT ·AndUint64(SB),$0-16 - MOVQ addr+0(FP), BP +TEXT ·AndUint64(SB),NOSPLIT,$0-16 + MOVQ addr+0(FP), BX MOVQ val+8(FP), AX LOCK - ANDQ AX, 0(BP) + ANDQ AX, 0(BX) RET -TEXT ·OrUint64(SB),$0-16 - MOVQ addr+0(FP), BP +TEXT ·OrUint64(SB),NOSPLIT,$0-16 + MOVQ addr+0(FP), BX MOVQ val+8(FP), AX LOCK - ORQ AX, 0(BP) + ORQ AX, 0(BX) RET -TEXT ·XorUint64(SB),$0-16 - MOVQ addr+0(FP), BP +TEXT ·XorUint64(SB),NOSPLIT,$0-16 + MOVQ addr+0(FP), BX MOVQ val+8(FP), AX LOCK - XORQ AX, 0(BP) + XORQ AX, 0(BX) RET -TEXT ·CompareAndSwapUint64(SB),$0-32 +TEXT ·CompareAndSwapUint64(SB),NOSPLIT,$0-32 MOVQ addr+0(FP), DI MOVQ old+8(FP), AX MOVQ new+16(FP), DX diff --git a/pkg/atomicbitops/atomicbitops_arm64.s b/pkg/atomicbitops/atomicbitops_arm64.s index 5c780851b..644a6bca5 100644 --- a/pkg/atomicbitops/atomicbitops_arm64.s +++ b/pkg/atomicbitops/atomicbitops_arm64.s @@ -16,7 +16,7 @@ #include "textflag.h" -TEXT ·AndUint32(SB),$0-12 +TEXT ·AndUint32(SB),NOSPLIT,$0-12 MOVD ptr+0(FP), R0 MOVW val+8(FP), R1 again: @@ -26,7 +26,7 @@ again: CBNZ R3, again RET -TEXT ·OrUint32(SB),$0-12 +TEXT ·OrUint32(SB),NOSPLIT,$0-12 MOVD ptr+0(FP), R0 MOVW val+8(FP), R1 again: @@ -36,7 +36,7 @@ again: CBNZ R3, again RET -TEXT ·XorUint32(SB),$0-12 +TEXT ·XorUint32(SB),NOSPLIT,$0-12 MOVD ptr+0(FP), R0 MOVW val+8(FP), R1 again: @@ -46,7 +46,7 @@ again: CBNZ R3, again RET -TEXT ·CompareAndSwapUint32(SB),$0-20 +TEXT ·CompareAndSwapUint32(SB),NOSPLIT,$0-20 MOVD addr+0(FP), R0 MOVW old+8(FP), R1 MOVW new+12(FP), R2 @@ -60,7 +60,7 @@ done: MOVW R3, prev+16(FP) RET -TEXT ·AndUint64(SB),$0-16 +TEXT ·AndUint64(SB),NOSPLIT,$0-16 MOVD ptr+0(FP), R0 MOVD val+8(FP), R1 again: @@ -70,7 +70,7 @@ again: CBNZ R3, again RET -TEXT ·OrUint64(SB),$0-16 +TEXT ·OrUint64(SB),NOSPLIT,$0-16 MOVD ptr+0(FP), R0 MOVD val+8(FP), R1 again: @@ -80,7 +80,7 @@ again: CBNZ R3, again RET -TEXT ·XorUint64(SB),$0-16 +TEXT ·XorUint64(SB),NOSPLIT,$0-16 MOVD ptr+0(FP), R0 MOVD val+8(FP), R1 again: @@ -90,7 +90,7 @@ again: CBNZ R3, again RET -TEXT ·CompareAndSwapUint64(SB),$0-32 +TEXT ·CompareAndSwapUint64(SB),NOSPLIT,$0-32 MOVD addr+0(FP), R0 MOVD old+8(FP), R1 MOVD new+16(FP), R2 diff --git a/pkg/atomicbitops/atomicbitops_noasm.go b/pkg/atomicbitops/atomicbitops_noasm.go index 474c0c815..af6b1362e 100644 --- a/pkg/atomicbitops/atomicbitops_noasm.go +++ b/pkg/atomicbitops/atomicbitops_noasm.go @@ -21,6 +21,7 @@ import ( "sync/atomic" ) +//go:nosplit func AndUint32(addr *uint32, val uint32) { for { o := atomic.LoadUint32(addr) @@ -31,6 +32,7 @@ func AndUint32(addr *uint32, val uint32) { } } +//go:nosplit func OrUint32(addr *uint32, val uint32) { for { o := atomic.LoadUint32(addr) @@ -41,6 +43,7 @@ func OrUint32(addr *uint32, val uint32) { } } +//go:nosplit func XorUint32(addr *uint32, val uint32) { for { o := atomic.LoadUint32(addr) @@ -51,6 +54,7 @@ func XorUint32(addr *uint32, val uint32) { } } +//go:nosplit func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) { for { prev = atomic.LoadUint32(addr) @@ -63,6 +67,7 @@ func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) { } } +//go:nosplit func AndUint64(addr *uint64, val uint64) { for { o := atomic.LoadUint64(addr) @@ -73,6 +78,7 @@ func AndUint64(addr *uint64, val uint64) { } } +//go:nosplit func OrUint64(addr *uint64, val uint64) { for { o := atomic.LoadUint64(addr) @@ -83,6 +89,7 @@ func OrUint64(addr *uint64, val uint64) { } } +//go:nosplit func XorUint64(addr *uint64, val uint64) { for { o := atomic.LoadUint64(addr) @@ -93,6 +100,7 @@ func XorUint64(addr *uint64, val uint64) { } } +//go:nosplit func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) { for { prev = atomic.LoadUint64(addr) diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go index 796efa240..59784eacb 100644 --- a/pkg/buffer/view_test.go +++ b/pkg/buffer/view_test.go @@ -509,6 +509,24 @@ func TestView(t *testing.T) { } } +func TestViewClone(t *testing.T) { + const ( + originalSize = 90 + bytesToDelete = 30 + ) + var v View + v.AppendOwned(bytes.Repeat([]byte{originalSize}, originalSize)) + + clonedV := v.Clone() + v.TrimFront(bytesToDelete) + if got, want := int(v.Size()), originalSize-bytesToDelete; got != want { + t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got) + } + if got := clonedV.Size(); got != originalSize { + t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got) + } +} + func TestViewPullUp(t *testing.T) { for _, tc := range []struct { desc string diff --git a/pkg/crypto/crypto_stdlib.go b/pkg/crypto/crypto_stdlib.go index 69e867386..28eba2ff6 100644 --- a/pkg/crypto/crypto_stdlib.go +++ b/pkg/crypto/crypto_stdlib.go @@ -19,14 +19,21 @@ package crypto import ( "crypto/ecdsa" + "crypto/elliptic" "crypto/sha512" + "fmt" "math/big" ) -// EcdsaVerify verifies the signature in r, s of hash using ECDSA and the -// public key, pub. Its return value records whether the signature is valid. -func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) (bool, error) { - return ecdsa.Verify(pub, hash, r, s), nil +// EcdsaP384Sha384Verify verifies the signature in r, s of hash using ECDSA +// P384 + SHA 384 and the public key, pub. Its return value records whether +// the signature is valid. +func EcdsaP384Sha384Verify(pub *ecdsa.PublicKey, data []byte, r, s *big.Int) (bool, error) { + if pub.Curve != elliptic.P384() { + return false, fmt.Errorf("unsupported key curve: want P-384, got %v", pub.Curve) + } + digest := sha512.Sum384(data) + return ecdsa.Verify(pub, digest[:], r, s), nil } // SumSha384 returns the SHA384 checksum of the data. diff --git a/pkg/eventfd/BUILD b/pkg/eventfd/BUILD new file mode 100644 index 000000000..02407cb99 --- /dev/null +++ b/pkg/eventfd/BUILD @@ -0,0 +1,22 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "eventfd", + srcs = [ + "eventfd.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/hostarch", + "//pkg/tcpip/link/rawfile", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +go_test( + name = "eventfd_test", + srcs = ["eventfd_test.go"], + library = ":eventfd", +) diff --git a/pkg/eventfd/eventfd.go b/pkg/eventfd/eventfd.go new file mode 100644 index 000000000..acdac01b8 --- /dev/null +++ b/pkg/eventfd/eventfd.go @@ -0,0 +1,115 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package eventfd wraps Linux's eventfd(2) syscall. +package eventfd + +import ( + "fmt" + "io" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" +) + +const sizeofUint64 = 8 + +// Eventfd represents a Linux eventfd object. +type Eventfd struct { + fd int +} + +// Create returns an initialized eventfd. +func Create() (Eventfd, error) { + fd, _, err := unix.RawSyscall(unix.SYS_EVENTFD2, 0, 0, 0) + if err != 0 { + return Eventfd{}, fmt.Errorf("failed to create eventfd: %v", error(err)) + } + if err := unix.SetNonblock(int(fd), true); err != nil { + unix.Close(int(fd)) + return Eventfd{}, err + } + return Eventfd{int(fd)}, nil +} + +// Wrap returns an initialized Eventfd using the provided fd. +func Wrap(fd int) Eventfd { + return Eventfd{fd} +} + +// Close closes the eventfd, after which it should not be used. +func (ev Eventfd) Close() error { + return unix.Close(ev.fd) +} + +// Dup copies the eventfd, calling dup(2) on the underlying file descriptor. +func (ev Eventfd) Dup() (Eventfd, error) { + other, err := unix.Dup(ev.fd) + if err != nil { + return Eventfd{}, fmt.Errorf("failed to dup: %v", other) + } + return Eventfd{other}, nil +} + +// Notify alerts other users of the eventfd. Users can receive alerts by +// calling Wait or Read. +func (ev Eventfd) Notify() error { + return ev.Write(1) +} + +// Write writes a specific value to the eventfd. +func (ev Eventfd) Write(val uint64) error { + var buf [sizeofUint64]byte + hostarch.ByteOrder.PutUint64(buf[:], val) + for { + n, err := unix.Write(ev.fd, buf[:]) + if err == unix.EINTR { + continue + } + if n != sizeofUint64 { + panic(fmt.Sprintf("short write to eventfd: got %d bytes, wanted %d", n, sizeofUint64)) + } + return err + } +} + +// Wait blocks until eventfd is non-zero (i.e. someone calls Notify or Write). +func (ev Eventfd) Wait() error { + _, err := ev.Read() + return err +} + +// Read blocks until eventfd is non-zero (i.e. someone calls Notify or Write) +// and returns the value read. +func (ev Eventfd) Read() (uint64, error) { + var tmp [sizeofUint64]byte + n, err := rawfile.BlockingReadUntranslated(ev.fd, tmp[:]) + if err != 0 { + return 0, err + } + if n == 0 { + return 0, io.EOF + } + if n != sizeofUint64 { + panic(fmt.Sprintf("short read from eventfd: got %d bytes, wanted %d", n, sizeofUint64)) + } + return hostarch.ByteOrder.Uint64(tmp[:]), nil +} + +// FD returns the underlying file descriptor. Use with care, as this breaks the +// Eventfd abstraction. +func (ev Eventfd) FD() int { + return ev.fd +} diff --git a/pkg/eventfd/eventfd_test.go b/pkg/eventfd/eventfd_test.go new file mode 100644 index 000000000..96998d530 --- /dev/null +++ b/pkg/eventfd/eventfd_test.go @@ -0,0 +1,75 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package eventfd + +import ( + "testing" + "time" +) + +func TestReadWrite(t *testing.T) { + efd, err := Create() + if err != nil { + t.Fatalf("failed to Create(): %v", err) + } + defer efd.Close() + + // Make sure we can read actual values + const want = 343 + if err := efd.Write(want); err != nil { + t.Fatalf("failed to write value: %d", want) + } + + got, err := efd.Read() + if err != nil { + t.Fatalf("failed to read value: %v", err) + } + if got != want { + t.Fatalf("Read(): got %d, but wanted %d", got, want) + } +} + +func TestWait(t *testing.T) { + efd, err := Create() + if err != nil { + t.Fatalf("failed to Create(): %v", err) + } + defer efd.Close() + + // There's no way to test with certainty that Wait() blocks indefinitely, but + // as a best-effort we can wait a bit on it. + errCh := make(chan error) + go func() { + errCh <- efd.Wait() + }() + select { + case err := <-errCh: + t.Fatalf("Wait() returned without a call to Notify(): %v", err) + case <-time.After(500 * time.Millisecond): + } + + // Notify and check that Wait() returned. + if err := efd.Notify(); err != nil { + t.Fatalf("Notify() failed: %v", err) + } + select { + case err := <-errCh: + if err != nil { + t.Fatalf("Read() failed: %v", err) + } + case <-time.After(5 * time.Second): + t.Fatalf("Read() did not return after Notify()") + } +} diff --git a/pkg/lisafs/BUILD b/pkg/lisafs/BUILD new file mode 100644 index 000000000..313c1756d --- /dev/null +++ b/pkg/lisafs/BUILD @@ -0,0 +1,117 @@ +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +go_template_instance( + name = "control_fd_refs", + out = "control_fd_refs.go", + package = "lisafs", + prefix = "controlFD", + template = "//pkg/refsvfs2:refs_template", + types = { + "T": "ControlFD", + }, +) + +go_template_instance( + name = "open_fd_refs", + out = "open_fd_refs.go", + package = "lisafs", + prefix = "openFD", + template = "//pkg/refsvfs2:refs_template", + types = { + "T": "OpenFD", + }, +) + +go_template_instance( + name = "control_fd_list", + out = "control_fd_list.go", + package = "lisafs", + prefix = "controlFD", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*ControlFD", + "Linker": "*ControlFD", + }, +) + +go_template_instance( + name = "open_fd_list", + out = "open_fd_list.go", + package = "lisafs", + prefix = "openFD", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*OpenFD", + "Linker": "*OpenFD", + }, +) + +go_library( + name = "lisafs", + srcs = [ + "channel.go", + "client.go", + "client_file.go", + "communicator.go", + "connection.go", + "control_fd_list.go", + "control_fd_refs.go", + "fd.go", + "handlers.go", + "lisafs.go", + "message.go", + "open_fd_list.go", + "open_fd_refs.go", + "sample_message.go", + "server.go", + "sock.go", + ], + marshal = True, + deps = [ + "//pkg/abi/linux", + "//pkg/cleanup", + "//pkg/context", + "//pkg/fdchannel", + "//pkg/flipcall", + "//pkg/fspath", + "//pkg/hostarch", + "//pkg/log", + "//pkg/marshal/primitive", + "//pkg/p9", + "//pkg/refsvfs2", + "//pkg/sync", + "//pkg/unet", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +go_test( + name = "sock_test", + size = "small", + srcs = ["sock_test.go"], + library = ":lisafs", + deps = [ + "//pkg/marshal", + "//pkg/sync", + "//pkg/unet", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +go_test( + name = "connection_test", + size = "small", + srcs = ["connection_test.go"], + deps = [ + ":lisafs", + "//pkg/sync", + "//pkg/unet", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/lisafs/README.md b/pkg/lisafs/README.md index 51d0d40e5..6b857321a 100644 --- a/pkg/lisafs/README.md +++ b/pkg/lisafs/README.md @@ -1,5 +1,8 @@ # Replacing 9P +NOTE: LISAFS is **NOT** production ready. There are still some security concerns +that must be resolved first. + ## Background The Linux filesystem model consists of the following key aspects (modulo mounts, diff --git a/pkg/lisafs/channel.go b/pkg/lisafs/channel.go new file mode 100644 index 000000000..301212e51 --- /dev/null +++ b/pkg/lisafs/channel.go @@ -0,0 +1,190 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "math" + "runtime" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/fdchannel" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" +) + +var ( + chanHeaderLen = uint32((*channelHeader)(nil).SizeBytes()) +) + +// maxChannels returns the number of channels a client can create. +// +// The server will reject channel creation requests beyond this (per client). +// Note that we don't want the number of channels to be too large, because each +// accounts for a large region of shared memory. +// TODO(gvisor.dev/issue/6313): Tune the number of channels. +func maxChannels() int { + maxChans := runtime.GOMAXPROCS(0) + if maxChans < 2 { + maxChans = 2 + } + if maxChans > 4 { + maxChans = 4 + } + return maxChans +} + +// channel implements Communicator and represents the communication endpoint +// for the client and server and is used to perform fast IPC. Apart from +// communicating data, a channel is also capable of donating file descriptors. +type channel struct { + fdTracker + dead bool + data flipcall.Endpoint + fdChan fdchannel.Endpoint +} + +var _ Communicator = (*channel)(nil) + +// PayloadBuf implements Communicator.PayloadBuf. +func (ch *channel) PayloadBuf(size uint32) []byte { + return ch.data.Data()[chanHeaderLen : chanHeaderLen+size] +} + +// SndRcvMessage implements Communicator.SndRcvMessage. +func (ch *channel) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) { + // Write header. Requests can not donate FDs. + ch.marshalHdr(m, 0 /* numFDs */) + + // One-shot communication. RPCs are expected to be quick rather than block. + rcvDataLen, err := ch.data.SendRecvFast(chanHeaderLen + payloadLen) + if err != nil { + // This channel is now unusable. + ch.dead = true + // Map the transport errors to EIO, but also log the real error. + log.Warningf("lisafs.sndRcvMessage: flipcall.Endpoint.SendRecv: %v", err) + return 0, 0, unix.EIO + } + + return ch.rcvMsg(rcvDataLen) +} + +func (ch *channel) shutdown() { + ch.data.Shutdown() +} + +func (ch *channel) destroy() { + ch.dead = true + ch.fdChan.Destroy() + ch.data.Destroy() +} + +// createChannel creates a server side channel. It returns a packet window +// descriptor (for the data channel) and an open socket for the FD channel. +func (c *Connection) createChannel(maxMessageSize uint32) (*channel, flipcall.PacketWindowDescriptor, int, error) { + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + // If c.channels is nil, the connection has closed. + if c.channels == nil || len(c.channels) >= maxChannels() { + return nil, flipcall.PacketWindowDescriptor{}, -1, unix.ENOSYS + } + ch := &channel{} + + // Set up data channel. + desc, err := c.channelAlloc.Allocate(flipcall.PacketHeaderBytes + int(chanHeaderLen+maxMessageSize)) + if err != nil { + return nil, flipcall.PacketWindowDescriptor{}, -1, err + } + if err := ch.data.Init(flipcall.ServerSide, desc); err != nil { + return nil, flipcall.PacketWindowDescriptor{}, -1, err + } + + // Set up FD channel. + fdSocks, err := fdchannel.NewConnectedSockets() + if err != nil { + ch.data.Destroy() + return nil, flipcall.PacketWindowDescriptor{}, -1, err + } + ch.fdChan.Init(fdSocks[0]) + clientFDSock := fdSocks[1] + + c.channels = append(c.channels, ch) + return ch, desc, clientFDSock, nil +} + +// sendFDs sends as many FDs as it can. The failure to send an FD does not +// cause an error and fail the entire RPC. FDs are considered supplementary +// responses that are not critical to the RPC response itself. The failure to +// send the (i)th FD will cause all the following FDs to not be sent as well +// because the order in which FDs are donated is important. +func (ch *channel) sendFDs(fds []int) uint8 { + numFDs := len(fds) + if numFDs == 0 { + return 0 + } + + if numFDs > math.MaxUint8 { + log.Warningf("dropping all FDs because too many FDs to donate: %v", numFDs) + return 0 + } + + for i, fd := range fds { + if err := ch.fdChan.SendFD(fd); err != nil { + log.Warningf("error occurred while sending (%d/%d)th FD on channel(%p): %v", i+1, numFDs, ch, err) + return uint8(i) + } + } + return uint8(numFDs) +} + +// channelHeader is the header present in front of each message received on +// flipcall endpoint when the protocol version being used is 1. +// +// +marshal +type channelHeader struct { + message MID + numFDs uint8 + _ uint8 // Need to make struct packed. +} + +func (ch *channel) marshalHdr(m MID, numFDs uint8) { + header := &channelHeader{ + message: m, + numFDs: numFDs, + } + header.MarshalUnsafe(ch.data.Data()) +} + +func (ch *channel) rcvMsg(dataLen uint32) (MID, uint32, error) { + if dataLen < chanHeaderLen { + log.Warningf("received data has size smaller than header length: %d", dataLen) + return 0, 0, unix.EIO + } + + // Read header first. + var header channelHeader + header.UnmarshalUnsafe(ch.data.Data()) + + // Read any FDs. + for i := 0; i < int(header.numFDs); i++ { + fd, err := ch.fdChan.RecvFDNonblock() + if err != nil { + log.Warningf("expected %d FDs, received %d successfully, got err after that: %v", header.numFDs, i, err) + break + } + ch.TrackFD(fd) + } + + return header.message, dataLen - chanHeaderLen, nil +} diff --git a/pkg/lisafs/client.go b/pkg/lisafs/client.go new file mode 100644 index 000000000..ccf1b9f72 --- /dev/null +++ b/pkg/lisafs/client.go @@ -0,0 +1,432 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "fmt" + "math" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/cleanup" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +const ( + // fdsToCloseBatchSize is the number of closed FDs batched before an Close + // RPC is made to close them all. fdsToCloseBatchSize is immutable. + fdsToCloseBatchSize = 100 +) + +// Client helps manage a connection to the lisafs server and pass messages +// efficiently. There is a 1:1 mapping between a Connection and a Client. +type Client struct { + // sockComm is the main socket by which this connections is established. + // Communication over the socket is synchronized by sockMu. + sockMu sync.Mutex + sockComm *sockCommunicator + + // channelsMu protects channels and availableChannels. + channelsMu sync.Mutex + // channels tracks all the channels. + channels []*channel + // availableChannels is a LIFO (stack) of channels available to be used. + availableChannels []*channel + // activeWg represents active channels. + activeWg sync.WaitGroup + + // watchdogWg only holds the watchdog goroutine. + watchdogWg sync.WaitGroup + + // supported caches information about which messages are supported. It is + // indexed by MID. An MID is supported if supported[MID] is true. + supported []bool + + // maxMessageSize is the maximum payload length (in bytes) that can be sent. + // It is initialized on Mount and is immutable. + maxMessageSize uint32 + + // fdsToClose tracks the FDs to close. It caches the FDs no longer being used + // by the client and closes them in one shot. It is not preserved across + // checkpoint/restore as FDIDs are not preserved. + fdsMu sync.Mutex + fdsToClose []FDID +} + +// NewClient creates a new client for communication with the server. It mounts +// the server and creates channels for fast IPC. NewClient takes ownership over +// the passed socket. On success, it returns the initialized client along with +// the root Inode. +func NewClient(sock *unet.Socket, mountPath string) (*Client, *Inode, error) { + maxChans := maxChannels() + c := &Client{ + sockComm: newSockComm(sock), + channels: make([]*channel, 0, maxChans), + availableChannels: make([]*channel, 0, maxChans), + maxMessageSize: 1 << 20, // 1 MB for now. + fdsToClose: make([]FDID, 0, fdsToCloseBatchSize), + } + + // Start a goroutine to check socket health. This goroutine is also + // responsible for client cleanup. + c.watchdogWg.Add(1) + go c.watchdog() + + // Clean everything up if anything fails. + cu := cleanup.Make(func() { + c.Close() + }) + defer cu.Clean() + + // Mount the server first. Assume Mount is supported so that we can make the + // Mount RPC below. + c.supported = make([]bool, Mount+1) + c.supported[Mount] = true + mountMsg := MountReq{ + MountPath: SizedString(mountPath), + } + var mountResp MountResp + if err := c.SndRcvMessage(Mount, uint32(mountMsg.SizeBytes()), mountMsg.MarshalBytes, mountResp.UnmarshalBytes, nil); err != nil { + return nil, nil, err + } + + // Initialize client. + c.maxMessageSize = uint32(mountResp.MaxMessageSize) + var maxSuppMID MID + for _, suppMID := range mountResp.SupportedMs { + if suppMID > maxSuppMID { + maxSuppMID = suppMID + } + } + c.supported = make([]bool, maxSuppMID+1) + for _, suppMID := range mountResp.SupportedMs { + c.supported[suppMID] = true + } + + // Create channels parallely so that channels can be used to create more + // channels and costly initialization like flipcall.Endpoint.Connect can + // proceed parallely. + var channelsWg sync.WaitGroup + channelErrs := make([]error, maxChans) + for i := 0; i < maxChans; i++ { + channelsWg.Add(1) + curChanID := i + go func() { + defer channelsWg.Done() + ch, err := c.createChannel() + if err != nil { + log.Warningf("channel creation failed: %v", err) + channelErrs[curChanID] = err + return + } + c.channelsMu.Lock() + c.channels = append(c.channels, ch) + c.availableChannels = append(c.availableChannels, ch) + c.channelsMu.Unlock() + }() + } + channelsWg.Wait() + + for _, channelErr := range channelErrs { + // Return the first non-nil channel creation error. + if channelErr != nil { + return nil, nil, channelErr + } + } + cu.Release() + + return c, &mountResp.Root, nil +} + +func (c *Client) watchdog() { + defer c.watchdogWg.Done() + + events := []unix.PollFd{ + { + Fd: int32(c.sockComm.FD()), + Events: unix.POLLHUP | unix.POLLRDHUP, + }, + } + + // Wait for a shutdown event. + for { + n, err := unix.Ppoll(events, nil, nil) + if err == unix.EINTR || err == unix.EAGAIN { + continue + } + if err != nil { + log.Warningf("lisafs.Client.watch(): %v", err) + } else if n != 1 { + log.Warningf("lisafs.Client.watch(): got %d events, wanted 1", n) + } + break + } + + // Shutdown all active channels and wait for them to complete. + c.shutdownActiveChans() + c.activeWg.Wait() + + // Close all channels. + c.channelsMu.Lock() + for _, ch := range c.channels { + ch.destroy() + } + c.channelsMu.Unlock() + + // Close main socket. + c.sockComm.destroy() +} + +func (c *Client) shutdownActiveChans() { + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + + availableChans := make(map[*channel]bool) + for _, ch := range c.availableChannels { + availableChans[ch] = true + } + for _, ch := range c.channels { + // A channel that is not available is active. + if _, ok := availableChans[ch]; !ok { + log.Debugf("shutting down active channel@%p...", ch) + ch.shutdown() + } + } + + // Prevent channels from becoming available and serving new requests. + c.availableChannels = nil +} + +// Close shuts down the main socket and waits for the watchdog to clean up. +func (c *Client) Close() { + // This shutdown has no effect if the watchdog has already fired and closed + // the main socket. + c.sockComm.shutdown() + c.watchdogWg.Wait() +} + +func (c *Client) createChannel() (*channel, error) { + var chanResp ChannelResp + var fds [2]int + if err := c.SndRcvMessage(Channel, 0, NoopMarshal, chanResp.UnmarshalUnsafe, fds[:]); err != nil { + return nil, err + } + if fds[0] < 0 || fds[1] < 0 { + closeFDs(fds[:]) + return nil, fmt.Errorf("insufficient FDs provided in Channel response: %v", fds) + } + + // Lets create the channel. + defer closeFDs(fds[:1]) // The data FD is not needed after this. + desc := flipcall.PacketWindowDescriptor{ + FD: fds[0], + Offset: chanResp.dataOffset, + Length: int(chanResp.dataLength), + } + + ch := &channel{} + if err := ch.data.Init(flipcall.ClientSide, desc); err != nil { + closeFDs(fds[1:]) + return nil, err + } + ch.fdChan.Init(fds[1]) // fdChan now owns this FD. + + // Only a connected channel is usable. + if err := ch.data.Connect(); err != nil { + ch.destroy() + return nil, err + } + return ch, nil +} + +// IsSupported returns true if this connection supports the passed message. +func (c *Client) IsSupported(m MID) bool { + return int(m) < len(c.supported) && c.supported[m] +} + +// CloseFDBatched either queues the passed FD to be closed or makes a batch +// RPC to close all the accumulated FDs-to-close. +func (c *Client) CloseFDBatched(ctx context.Context, fd FDID) { + c.fdsMu.Lock() + c.fdsToClose = append(c.fdsToClose, fd) + if len(c.fdsToClose) < fdsToCloseBatchSize { + c.fdsMu.Unlock() + return + } + + // Flush the cache. We should not hold fdsMu while making an RPC, so be sure + // to copy the fdsToClose to another buffer before unlocking fdsMu. + var toCloseArr [fdsToCloseBatchSize]FDID + toClose := toCloseArr[:len(c.fdsToClose)] + copy(toClose, c.fdsToClose) + + // Clear fdsToClose so other FDIDs can be appended. + c.fdsToClose = c.fdsToClose[:0] + c.fdsMu.Unlock() + + req := CloseReq{FDs: toClose} + ctx.UninterruptibleSleepStart(false) + err := c.SndRcvMessage(Close, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + if err != nil { + log.Warningf("lisafs: batch closing FDs returned error: %v", err) + } +} + +// SyncFDs makes a Fsync RPC to sync multiple FDs. +func (c *Client) SyncFDs(ctx context.Context, fds []FDID) error { + if len(fds) == 0 { + return nil + } + req := FsyncReq{FDs: fds} + ctx.UninterruptibleSleepStart(false) + err := c.SndRcvMessage(FSync, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// SndRcvMessage invokes reqMarshal to marshal the request onto the payload +// buffer, wakes up the server to process the request, waits for the response +// and invokes respUnmarshal with the response payload. respFDs is populated +// with the received FDs, extra fields are set to -1. +// +// Note that the function arguments intentionally accept marshal.Marshallable +// functions like Marshal{Bytes/Unsafe} and Unmarshal{Bytes/Unsafe} instead of +// directly accepting the marshal.Marshallable interface. Even though just +// accepting marshal.Marshallable is cleaner, it leads to a heap allocation +// (even if that interface variable itself does not escape). In other words, +// implicit conversion to an interface leads to an allocation. +// +// Precondition: reqMarshal and respUnmarshal must be non-nil. +func (c *Client) SndRcvMessage(m MID, payloadLen uint32, reqMarshal func(dst []byte), respUnmarshal func(src []byte), respFDs []int) error { + if !c.IsSupported(m) { + return unix.EOPNOTSUPP + } + if payloadLen > c.maxMessageSize { + log.Warningf("message %d has message size = %d which is larger than client.maxMessageSize = %d", m, payloadLen, c.maxMessageSize) + return unix.EIO + } + wantFDs := len(respFDs) + if wantFDs > math.MaxUint8 { + log.Warningf("want too many FDs: %d", wantFDs) + return unix.EINVAL + } + + // Acquire a communicator. + comm := c.acquireCommunicator() + defer c.releaseCommunicator(comm) + + // Marshal the request into comm's payload buffer and make the RPC. + reqMarshal(comm.PayloadBuf(payloadLen)) + respM, respPayloadLen, err := comm.SndRcvMessage(m, payloadLen, uint8(wantFDs)) + + // Handle FD donation. + rcvFDs := comm.ReleaseFDs() + if numRcvFDs := len(rcvFDs); numRcvFDs+wantFDs > 0 { + // releasedFDs is memory owned by comm which can not be returned to caller. + // Copy it into the caller's buffer. + numFDCopied := copy(respFDs, rcvFDs) + if numFDCopied < numRcvFDs { + log.Warningf("%d unexpected FDs were donated by the server, wanted", numRcvFDs-numFDCopied, wantFDs) + closeFDs(rcvFDs[numFDCopied:]) + } + if numFDCopied < wantFDs { + for i := numFDCopied; i < wantFDs; i++ { + respFDs[i] = -1 + } + } + } + + // Error cases. + if err != nil { + closeFDs(respFDs) + return err + } + if respM == Error { + closeFDs(respFDs) + var resp ErrorResp + resp.UnmarshalUnsafe(comm.PayloadBuf(respPayloadLen)) + return unix.Errno(resp.errno) + } + if respM != m { + closeFDs(respFDs) + log.Warningf("sent %d message but got %d in response", m, respM) + return unix.EINVAL + } + + // Success. The payload must be unmarshalled *before* comm is released. + respUnmarshal(comm.PayloadBuf(respPayloadLen)) + return nil +} + +// Postcondition: releaseCommunicator() must be called on the returned value. +func (c *Client) acquireCommunicator() Communicator { + // Prefer using channel over socket because: + // - Channel uses a shared memory region for passing messages. IO from shared + // memory is faster and does not involve making a syscall. + // - No intermediate buffer allocation needed. With a channel, the message + // can be directly pasted into the shared memory region. + if ch := c.getChannel(); ch != nil { + return ch + } + + c.sockMu.Lock() + return c.sockComm +} + +// Precondition: comm must have been acquired via acquireCommunicator(). +func (c *Client) releaseCommunicator(comm Communicator) { + switch t := comm.(type) { + case *sockCommunicator: + c.sockMu.Unlock() // +checklocksforce: locked in acquireCommunicator(). + case *channel: + c.releaseChannel(t) + default: + panic(fmt.Sprintf("unknown communicator type %T", t)) + } +} + +// getChannel pops a channel from the available channels stack. The caller must +// release the channel after use. +func (c *Client) getChannel() *channel { + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + if len(c.availableChannels) == 0 { + return nil + } + + idx := len(c.availableChannels) - 1 + ch := c.availableChannels[idx] + c.availableChannels = c.availableChannels[:idx] + c.activeWg.Add(1) + return ch +} + +// releaseChannel pushes the passed channel onto the available channel stack if +// reinsert is true. +func (c *Client) releaseChannel(ch *channel) { + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + + // If availableChannels is nil, then watchdog has fired and the client is + // shutting down. So don't make this channel available again. + if !ch.dead && c.availableChannels != nil { + c.availableChannels = append(c.availableChannels, ch) + } + c.activeWg.Done() +} diff --git a/pkg/lisafs/client_file.go b/pkg/lisafs/client_file.go new file mode 100644 index 000000000..170c15705 --- /dev/null +++ b/pkg/lisafs/client_file.go @@ -0,0 +1,528 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "fmt" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal/primitive" +) + +// ClientFD is a wrapper around FDID that provides client-side utilities +// so that RPC making is easier. +type ClientFD struct { + fd FDID + client *Client +} + +// ID returns the underlying FDID. +func (f *ClientFD) ID() FDID { + return f.fd +} + +// Client returns the backing Client. +func (f *ClientFD) Client() *Client { + return f.client +} + +// NewFD initializes a new ClientFD. +func (c *Client) NewFD(fd FDID) ClientFD { + return ClientFD{ + client: c, + fd: fd, + } +} + +// Ok returns true if the underlying FD is ok. +func (f *ClientFD) Ok() bool { + return f.fd.Ok() +} + +// CloseBatched queues this FD to be closed on the server and resets f.fd. +// This maybe invoke the Close RPC if the queue is full. +func (f *ClientFD) CloseBatched(ctx context.Context) { + f.client.CloseFDBatched(ctx, f.fd) + f.fd = InvalidFDID +} + +// Close closes this FD immediately (invoking a Close RPC). Consider using +// CloseBatched if closing this FD on remote right away is not critical. +func (f *ClientFD) Close(ctx context.Context) error { + fdArr := [1]FDID{f.fd} + req := CloseReq{FDs: fdArr[:]} + + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(Close, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// OpenAt makes the OpenAt RPC. +func (f *ClientFD) OpenAt(ctx context.Context, flags uint32) (FDID, int, error) { + req := OpenAtReq{ + FD: f.fd, + Flags: flags, + } + var respFD [1]int + var resp OpenAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(OpenAt, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalUnsafe, respFD[:]) + ctx.UninterruptibleSleepFinish(false) + return resp.NewFD, respFD[0], err +} + +// OpenCreateAt makes the OpenCreateAt RPC. +func (f *ClientFD) OpenCreateAt(ctx context.Context, name string, flags uint32, mode linux.FileMode, uid UID, gid GID) (Inode, FDID, int, error) { + var req OpenCreateAtReq + req.DirFD = f.fd + req.Name = SizedString(name) + req.Flags = primitive.Uint32(flags) + req.Mode = mode + req.UID = uid + req.GID = gid + + var respFD [1]int + var resp OpenCreateAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(OpenCreateAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, respFD[:]) + ctx.UninterruptibleSleepFinish(false) + return resp.Child, resp.NewFD, respFD[0], err +} + +// StatTo makes the Fstat RPC and populates stat with the result. +func (f *ClientFD) StatTo(ctx context.Context, stat *linux.Statx) error { + req := StatReq{FD: f.fd} + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FStat, uint32(req.SizeBytes()), req.MarshalUnsafe, stat.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// Sync makes the Fsync RPC. +func (f *ClientFD) Sync(ctx context.Context) error { + req := FsyncReq{FDs: []FDID{f.fd}} + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FSync, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// chunkify applies fn to buf in chunks based on chunkSize. +func chunkify(chunkSize uint64, buf []byte, fn func([]byte, uint64) (uint64, error)) (uint64, error) { + toProcess := uint64(len(buf)) + var ( + totalProcessed uint64 + curProcessed uint64 + off uint64 + err error + ) + for { + if totalProcessed == toProcess { + return totalProcessed, nil + } + + if totalProcessed+chunkSize > toProcess { + curProcessed, err = fn(buf[totalProcessed:], off) + } else { + curProcessed, err = fn(buf[totalProcessed:totalProcessed+chunkSize], off) + } + totalProcessed += curProcessed + off += curProcessed + + if err != nil { + return totalProcessed, err + } + + // Return partial result immediately. + if curProcessed < chunkSize { + return totalProcessed, nil + } + + // If we received more bytes than we ever requested, this is a problem. + if totalProcessed > toProcess { + panic(fmt.Sprintf("bytes completed (%d)) > requested (%d)", totalProcessed, toProcess)) + } + } +} + +// Read makes the PRead RPC. +func (f *ClientFD) Read(ctx context.Context, dst []byte, offset uint64) (uint64, error) { + var resp PReadResp + // maxDataReadSize represents the maximum amount of data we can read at once + // (maximum message size - metadata size present in resp). Uninitialized + // resp.SizeBytes() correctly returns the metadata size only (since the read + // buffer is empty). + maxDataReadSize := uint64(f.client.maxMessageSize) - uint64(resp.SizeBytes()) + return chunkify(maxDataReadSize, dst, func(buf []byte, curOff uint64) (uint64, error) { + req := PReadReq{ + Offset: offset + curOff, + FD: f.fd, + Count: uint32(len(buf)), + } + + // This will be unmarshalled into. Already set Buf so that we don't need to + // allocate a temporary buffer during unmarshalling. + // PReadResp.UnmarshalBytes expects this to be set. + resp.Buf = buf + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(PRead, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return uint64(resp.NumBytes), err + }) +} + +// Write makes the PWrite RPC. +func (f *ClientFD) Write(ctx context.Context, src []byte, offset uint64) (uint64, error) { + var req PWriteReq + // maxDataWriteSize represents the maximum amount of data we can write at + // once (maximum message size - metadata size present in req). Uninitialized + // req.SizeBytes() correctly returns the metadata size only (since the write + // buffer is empty). + maxDataWriteSize := uint64(f.client.maxMessageSize) - uint64(req.SizeBytes()) + return chunkify(maxDataWriteSize, src, func(buf []byte, curOff uint64) (uint64, error) { + req = PWriteReq{ + Offset: primitive.Uint64(offset + curOff), + FD: f.fd, + NumBytes: primitive.Uint32(len(buf)), + Buf: buf, + } + + var resp PWriteResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(PWrite, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.Count, err + }) +} + +// MkdirAt makes the MkdirAt RPC. +func (f *ClientFD) MkdirAt(ctx context.Context, name string, mode linux.FileMode, uid UID, gid GID) (*Inode, error) { + var req MkdirAtReq + req.DirFD = f.fd + req.Name = SizedString(name) + req.Mode = mode + req.UID = uid + req.GID = gid + + var resp MkdirAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(MkdirAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return &resp.ChildDir, err +} + +// SymlinkAt makes the SymlinkAt RPC. +func (f *ClientFD) SymlinkAt(ctx context.Context, name, target string, uid UID, gid GID) (*Inode, error) { + req := SymlinkAtReq{ + DirFD: f.fd, + Name: SizedString(name), + Target: SizedString(target), + UID: uid, + GID: gid, + } + + var resp SymlinkAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(SymlinkAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return &resp.Symlink, err +} + +// LinkAt makes the LinkAt RPC. +func (f *ClientFD) LinkAt(ctx context.Context, targetFD FDID, name string) (*Inode, error) { + req := LinkAtReq{ + DirFD: f.fd, + Target: targetFD, + Name: SizedString(name), + } + + var resp LinkAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(LinkAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return &resp.Link, err +} + +// MknodAt makes the MknodAt RPC. +func (f *ClientFD) MknodAt(ctx context.Context, name string, mode linux.FileMode, uid UID, gid GID, minor, major uint32) (*Inode, error) { + var req MknodAtReq + req.DirFD = f.fd + req.Name = SizedString(name) + req.Mode = mode + req.UID = uid + req.GID = gid + req.Minor = primitive.Uint32(minor) + req.Major = primitive.Uint32(major) + + var resp MknodAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(MknodAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return &resp.Child, err +} + +// SetStat makes the SetStat RPC. +func (f *ClientFD) SetStat(ctx context.Context, stat *linux.Statx) (uint32, error, error) { + req := SetStatReq{ + FD: f.fd, + Mask: stat.Mask, + Mode: uint32(stat.Mode), + UID: UID(stat.UID), + GID: GID(stat.GID), + Size: stat.Size, + Atime: linux.Timespec{ + Sec: stat.Atime.Sec, + Nsec: int64(stat.Atime.Nsec), + }, + Mtime: linux.Timespec{ + Sec: stat.Mtime.Sec, + Nsec: int64(stat.Mtime.Nsec), + }, + } + + var resp SetStatResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(SetStat, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.FailureMask, unix.Errno(resp.FailureErrNo), err +} + +// WalkMultiple makes the Walk RPC with multiple path components. +func (f *ClientFD) WalkMultiple(ctx context.Context, names []string) (WalkStatus, []Inode, error) { + req := WalkReq{ + DirFD: f.fd, + Path: StringArray(names), + } + + var resp WalkResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(Walk, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.Status, resp.Inodes, err +} + +// Walk makes the Walk RPC with just one path component to walk. +func (f *ClientFD) Walk(ctx context.Context, name string) (*Inode, error) { + req := WalkReq{ + DirFD: f.fd, + Path: []string{name}, + } + + var inode [1]Inode + resp := WalkResp{Inodes: inode[:]} + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(Walk, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + if err != nil { + return nil, err + } + + switch resp.Status { + case WalkComponentDoesNotExist: + return nil, unix.ENOENT + case WalkComponentSymlink: + // f is not a directory which can be walked on. + return nil, unix.ENOTDIR + } + + if n := len(resp.Inodes); n > 1 { + for i := range resp.Inodes { + f.client.CloseFDBatched(ctx, resp.Inodes[i].ControlFD) + } + log.Warningf("requested to walk one component, but got %d results", n) + return nil, unix.EIO + } else if n == 0 { + log.Warningf("walk has success status but no results returned") + return nil, unix.ENOENT + } + return &inode[0], err +} + +// WalkStat makes the WalkStat RPC with multiple path components to walk. +func (f *ClientFD) WalkStat(ctx context.Context, names []string) ([]linux.Statx, error) { + req := WalkReq{ + DirFD: f.fd, + Path: StringArray(names), + } + + var resp WalkStatResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(WalkStat, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.Stats, err +} + +// StatFSTo makes the FStatFS RPC and populates statFS with the result. +func (f *ClientFD) StatFSTo(ctx context.Context, statFS *StatFS) error { + req := FStatFSReq{FD: f.fd} + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FStatFS, uint32(req.SizeBytes()), req.MarshalUnsafe, statFS.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// Allocate makes the FAllocate RPC. +func (f *ClientFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + req := FAllocateReq{ + FD: f.fd, + Mode: mode, + Offset: offset, + Length: length, + } + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FAllocate, uint32(req.SizeBytes()), req.MarshalUnsafe, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// ReadLinkAt makes the ReadLinkAt RPC. +func (f *ClientFD) ReadLinkAt(ctx context.Context) (string, error) { + req := ReadLinkAtReq{FD: f.fd} + var resp ReadLinkAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(ReadLinkAt, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return string(resp.Target), err +} + +// Flush makes the Flush RPC. +func (f *ClientFD) Flush(ctx context.Context) error { + if !f.client.IsSupported(Flush) { + // If Flush is not supported, it probably means that it would be a noop. + return nil + } + req := FlushReq{FD: f.fd} + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(Flush, uint32(req.SizeBytes()), req.MarshalUnsafe, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// Connect makes the Connect RPC. +func (f *ClientFD) Connect(ctx context.Context, sockType linux.SockType) (int, error) { + req := ConnectReq{FD: f.fd, SockType: uint32(sockType)} + var sockFD [1]int + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(Connect, uint32(req.SizeBytes()), req.MarshalUnsafe, NoopUnmarshal, sockFD[:]) + ctx.UninterruptibleSleepFinish(false) + if err == nil && sockFD[0] < 0 { + err = unix.EBADF + } + return sockFD[0], err +} + +// UnlinkAt makes the UnlinkAt RPC. +func (f *ClientFD) UnlinkAt(ctx context.Context, name string, flags uint32) error { + req := UnlinkAtReq{ + DirFD: f.fd, + Name: SizedString(name), + Flags: primitive.Uint32(flags), + } + + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(UnlinkAt, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// RenameTo makes the RenameAt RPC which renames f to newDirFD directory with +// name newName. +func (f *ClientFD) RenameTo(ctx context.Context, newDirFD FDID, newName string) error { + req := RenameAtReq{ + Renamed: f.fd, + NewDir: newDirFD, + NewName: SizedString(newName), + } + + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(RenameAt, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// Getdents64 makes the Getdents64 RPC. +func (f *ClientFD) Getdents64(ctx context.Context, count int32) ([]Dirent64, error) { + req := Getdents64Req{ + DirFD: f.fd, + Count: count, + } + + var resp Getdents64Resp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(Getdents64, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.Dirents, err +} + +// ListXattr makes the FListXattr RPC. +func (f *ClientFD) ListXattr(ctx context.Context, size uint64) ([]string, error) { + req := FListXattrReq{ + FD: f.fd, + Size: size, + } + + var resp FListXattrResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FListXattr, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.Xattrs, err +} + +// GetXattr makes the FGetXattr RPC. +func (f *ClientFD) GetXattr(ctx context.Context, name string, size uint64) (string, error) { + req := FGetXattrReq{ + FD: f.fd, + Name: SizedString(name), + BufSize: primitive.Uint32(size), + } + + var resp FGetXattrResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FGetXattr, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return string(resp.Value), err +} + +// SetXattr makes the FSetXattr RPC. +func (f *ClientFD) SetXattr(ctx context.Context, name string, value string, flags uint32) error { + req := FSetXattrReq{ + FD: f.fd, + Name: SizedString(name), + Value: SizedString(value), + Flags: primitive.Uint32(flags), + } + + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FSetXattr, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// RemoveXattr makes the FRemoveXattr RPC. +func (f *ClientFD) RemoveXattr(ctx context.Context, name string) error { + req := FRemoveXattrReq{ + FD: f.fd, + Name: SizedString(name), + } + + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FRemoveXattr, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} diff --git a/pkg/lisafs/communicator.go b/pkg/lisafs/communicator.go new file mode 100644 index 000000000..ec2035158 --- /dev/null +++ b/pkg/lisafs/communicator.go @@ -0,0 +1,80 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import "golang.org/x/sys/unix" + +// Communicator is a server side utility which represents exactly how the +// server is communicating with the client. +type Communicator interface { + // PayloadBuf returns a slice to the payload section of its internal buffer + // where the message can be marshalled. The handlers should use this to + // populate the payload buffer with the message. + // + // The payload buffer contents *should* be preserved across calls with + // different sizes. Note that this is not a guarantee, because a compromised + // owner of a "shared" payload buffer can tamper with its contents anytime, + // even when it's not its turn to do so. + PayloadBuf(size uint32) []byte + + // SndRcvMessage sends message m. The caller must have populated PayloadBuf() + // with payloadLen bytes. The caller expects to receive wantFDs FDs. + // Any received FDs must be accessible via ReleaseFDs(). It returns the + // response message along with the response payload length. + SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) + + // DonateFD makes fd non-blocking and starts tracking it. The next call to + // ReleaseFDs will include fd in the order it was added. Communicator takes + // ownership of fd. Server side should call this. + DonateFD(fd int) error + + // Track starts tracking fd. The next call to ReleaseFDs will include fd in + // the order it was added. Communicator takes ownership of fd. Client side + // should use this for accumulating received FDs. + TrackFD(fd int) + + // ReleaseFDs returns the accumulated FDs and stops tracking them. The + // ownership of the FDs is transferred to the caller. + ReleaseFDs() []int +} + +// fdTracker is a partial implementation of Communicator. It can be embedded in +// Communicator implementations to keep track of FD donations. +type fdTracker struct { + fds []int +} + +// DonateFD implements Communicator.DonateFD. +func (d *fdTracker) DonateFD(fd int) error { + // Make sure the FD is non-blocking. + if err := unix.SetNonblock(fd, true); err != nil { + unix.Close(fd) + return err + } + d.TrackFD(fd) + return nil +} + +// TrackFD implements Communicator.TrackFD. +func (d *fdTracker) TrackFD(fd int) { + d.fds = append(d.fds, fd) +} + +// ReleaseFDs implements Communicator.ReleaseFDs. +func (d *fdTracker) ReleaseFDs() []int { + ret := d.fds + d.fds = d.fds[:0] + return ret +} diff --git a/pkg/lisafs/connection.go b/pkg/lisafs/connection.go new file mode 100644 index 000000000..f6e5ecb4f --- /dev/null +++ b/pkg/lisafs/connection.go @@ -0,0 +1,320 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +// Connection represents a connection between a mount point in the client and a +// mount point in the server. It is owned by the server on which it was started +// and facilitates communication with the client mount. +// +// Each connection is set up using a unix domain socket. One end is owned by +// the server and the other end is owned by the client. The connection may +// spawn additional comunicational channels for the same mount for increased +// RPC concurrency. +type Connection struct { + // server is the server on which this connection was created. It is immutably + // associated with it for its entire lifetime. + server *Server + + // mounted is a one way flag indicating whether this connection has been + // mounted correctly and the server is initialized properly. + mounted bool + + // readonly indicates if this connection is readonly. All write operations + // will fail with EROFS. + readonly bool + + // sockComm is the main socket by which this connections is established. + sockComm *sockCommunicator + + // channelsMu protects channels. + channelsMu sync.Mutex + // channels keeps track of all open channels. + channels []*channel + + // activeWg represents active channels. + activeWg sync.WaitGroup + + // reqGate counts requests that are still being handled. + reqGate sync.Gate + + // channelAlloc is used to allocate memory for channels. + channelAlloc *flipcall.PacketWindowAllocator + + fdsMu sync.RWMutex + // fds keeps tracks of open FDs on this server. It is protected by fdsMu. + fds map[FDID]genericFD + // nextFDID is the next available FDID. It is protected by fdsMu. + nextFDID FDID +} + +// CreateConnection initializes a new connection - creating a server if +// required. The connection must be started separately. +func (s *Server) CreateConnection(sock *unet.Socket, readonly bool) (*Connection, error) { + c := &Connection{ + sockComm: newSockComm(sock), + server: s, + readonly: readonly, + channels: make([]*channel, 0, maxChannels()), + fds: make(map[FDID]genericFD), + nextFDID: InvalidFDID + 1, + } + + alloc, err := flipcall.NewPacketWindowAllocator() + if err != nil { + return nil, err + } + c.channelAlloc = alloc + return c, nil +} + +// Server returns the associated server. +func (c *Connection) Server() *Server { + return c.server +} + +// ServerImpl returns the associated server implementation. +func (c *Connection) ServerImpl() ServerImpl { + return c.server.impl +} + +// Run defines the lifecycle of a connection. +func (c *Connection) Run() { + defer c.close() + + // Start handling requests on this connection. + for { + m, payloadLen, err := c.sockComm.rcvMsg(0 /* wantFDs */) + if err != nil { + log.Debugf("sock read failed, closing connection: %v", err) + return + } + + respM, respPayloadLen, respFDs := c.handleMsg(c.sockComm, m, payloadLen) + err = c.sockComm.sndPrepopulatedMsg(respM, respPayloadLen, respFDs) + closeFDs(respFDs) + if err != nil { + log.Debugf("sock write failed, closing connection: %v", err) + return + } + } +} + +// service starts servicing the passed channel until the channel is shutdown. +// This is a blocking method and hence must be called in a separate goroutine. +func (c *Connection) service(ch *channel) error { + rcvDataLen, err := ch.data.RecvFirst() + if err != nil { + return err + } + for rcvDataLen > 0 { + m, payloadLen, err := ch.rcvMsg(rcvDataLen) + if err != nil { + return err + } + respM, respPayloadLen, respFDs := c.handleMsg(ch, m, payloadLen) + numFDs := ch.sendFDs(respFDs) + closeFDs(respFDs) + + ch.marshalHdr(respM, numFDs) + rcvDataLen, err = ch.data.SendRecv(respPayloadLen + chanHeaderLen) + if err != nil { + return err + } + } + return nil +} + +func (c *Connection) respondError(comm Communicator, err unix.Errno) (MID, uint32, []int) { + resp := &ErrorResp{errno: uint32(err)} + respLen := uint32(resp.SizeBytes()) + resp.MarshalUnsafe(comm.PayloadBuf(respLen)) + return Error, respLen, nil +} + +func (c *Connection) handleMsg(comm Communicator, m MID, payloadLen uint32) (MID, uint32, []int) { + if !c.reqGate.Enter() { + // c.close() has been called; the connection is shutting down. + return c.respondError(comm, unix.ECONNRESET) + } + defer c.reqGate.Leave() + + if !c.mounted && m != Mount { + log.Warningf("connection must first be mounted") + return c.respondError(comm, unix.EINVAL) + } + + // Check if the message is supported for forward compatibility. + if int(m) >= len(c.server.handlers) || c.server.handlers[m] == nil { + log.Warningf("received request which is not supported by the server, MID = %d", m) + return c.respondError(comm, unix.EOPNOTSUPP) + } + + // Try handling the request. + respPayloadLen, err := c.server.handlers[m](c, comm, payloadLen) + fds := comm.ReleaseFDs() + if err != nil { + closeFDs(fds) + return c.respondError(comm, p9.ExtractErrno(err)) + } + + return m, respPayloadLen, fds +} + +func (c *Connection) close() { + // Wait for completion of all inflight requests. This is mostly so that if + // a request is stuck, the sandbox supervisor has the opportunity to kill + // us with SIGABRT to get a stack dump of the offending handler. + c.reqGate.Close() + + // Shutdown and clean up channels. + c.channelsMu.Lock() + for _, ch := range c.channels { + ch.shutdown() + } + c.activeWg.Wait() + for _, ch := range c.channels { + ch.destroy() + } + // This is to prevent additional channels from being created. + c.channels = nil + c.channelsMu.Unlock() + + // Free the channel memory. + if c.channelAlloc != nil { + c.channelAlloc.Destroy() + } + + // Ensure the connection is closed. + c.sockComm.destroy() + + // Cleanup all FDs. + c.fdsMu.Lock() + for fdid := range c.fds { + fd := c.removeFDLocked(fdid) + fd.DecRef(nil) // Drop the ref held by c. + } + c.fdsMu.Unlock() +} + +// The caller gains a ref on the FD on success. +func (c *Connection) lookupFD(id FDID) (genericFD, error) { + c.fdsMu.RLock() + defer c.fdsMu.RUnlock() + + fd, ok := c.fds[id] + if !ok { + return nil, unix.EBADF + } + fd.IncRef() + return fd, nil +} + +// LookupControlFD retrieves the control FD identified by id on this +// connection. On success, the caller gains a ref on the FD. +func (c *Connection) LookupControlFD(id FDID) (*ControlFD, error) { + fd, err := c.lookupFD(id) + if err != nil { + return nil, err + } + + cfd, ok := fd.(*ControlFD) + if !ok { + fd.DecRef(nil) + return nil, unix.EINVAL + } + return cfd, nil +} + +// LookupOpenFD retrieves the open FD identified by id on this +// connection. On success, the caller gains a ref on the FD. +func (c *Connection) LookupOpenFD(id FDID) (*OpenFD, error) { + fd, err := c.lookupFD(id) + if err != nil { + return nil, err + } + + ofd, ok := fd.(*OpenFD) + if !ok { + fd.DecRef(nil) + return nil, unix.EINVAL + } + return ofd, nil +} + +// insertFD inserts the passed fd into the internal datastructure to track FDs. +// The caller must hold a ref on fd which is transferred to the connection. +func (c *Connection) insertFD(fd genericFD) FDID { + c.fdsMu.Lock() + defer c.fdsMu.Unlock() + + res := c.nextFDID + c.nextFDID++ + if c.nextFDID < res { + panic("ran out of FDIDs") + } + c.fds[res] = fd + return res +} + +// RemoveFD makes c stop tracking the passed FDID and drops its ref on it. +func (c *Connection) RemoveFD(id FDID) { + c.fdsMu.Lock() + fd := c.removeFDLocked(id) + c.fdsMu.Unlock() + if fd != nil { + // Drop the ref held by c. This can take arbitrarily long. So do not hold + // c.fdsMu while calling it. + fd.DecRef(nil) + } +} + +// RemoveControlFDLocked is the same as RemoveFD with added preconditions. +// +// Preconditions: +// * server's rename mutex must at least be read locked. +// * id must be pointing to a control FD. +func (c *Connection) RemoveControlFDLocked(id FDID) { + c.fdsMu.Lock() + fd := c.removeFDLocked(id) + c.fdsMu.Unlock() + if fd != nil { + // Drop the ref held by c. This can take arbitrarily long. So do not hold + // c.fdsMu while calling it. + fd.(*ControlFD).DecRefLocked() + } +} + +// removeFDLocked makes c stop tracking the passed FDID. Note that the caller +// must drop ref on the returned fd (preferably without holding c.fdsMu). +// +// Precondition: c.fdsMu is locked. +func (c *Connection) removeFDLocked(id FDID) genericFD { + fd := c.fds[id] + if fd == nil { + log.Warningf("removeFDLocked called on non-existent FDID %d", id) + return nil + } + delete(c.fds, id) + return fd +} diff --git a/pkg/lisafs/connection_test.go b/pkg/lisafs/connection_test.go new file mode 100644 index 000000000..28ba47112 --- /dev/null +++ b/pkg/lisafs/connection_test.go @@ -0,0 +1,194 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connection_test + +import ( + "reflect" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/lisafs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +const ( + dynamicMsgID = lisafs.Channel + 1 + versionMsgID = dynamicMsgID + 1 +) + +var handlers = [...]lisafs.RPCHandler{ + lisafs.Error: lisafs.ErrorHandler, + lisafs.Mount: lisafs.MountHandler, + lisafs.Channel: lisafs.ChannelHandler, + dynamicMsgID: dynamicMsgHandler, + versionMsgID: versionHandler, +} + +// testServer implements lisafs.ServerImpl. +type testServer struct { + lisafs.Server +} + +var _ lisafs.ServerImpl = (*testServer)(nil) + +type testControlFD struct { + lisafs.ControlFD + lisafs.ControlFDImpl +} + +func (fd *testControlFD) FD() *lisafs.ControlFD { + return &fd.ControlFD +} + +// Mount implements lisafs.Mount. +func (s *testServer) Mount(c *lisafs.Connection, mountPath string) (lisafs.ControlFDImpl, lisafs.Inode, error) { + return &testControlFD{}, lisafs.Inode{ControlFD: 1}, nil +} + +// MaxMessageSize implements lisafs.MaxMessageSize. +func (s *testServer) MaxMessageSize() uint32 { + return lisafs.MaxMessageSize() +} + +// SupportedMessages implements lisafs.ServerImpl.SupportedMessages. +func (s *testServer) SupportedMessages() []lisafs.MID { + return []lisafs.MID{ + lisafs.Mount, + lisafs.Channel, + dynamicMsgID, + versionMsgID, + } +} + +func runServerClient(t testing.TB, clientFn func(c *lisafs.Client)) { + serverSocket, clientSocket, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + + ts := &testServer{} + ts.Server.InitTestOnly(ts, handlers[:]) + conn, err := ts.CreateConnection(serverSocket, false /* readonly */) + if err != nil { + t.Fatalf("starting connection failed: %v", err) + return + } + ts.StartConnection(conn) + + c, _, err := lisafs.NewClient(clientSocket, "/") + if err != nil { + t.Fatalf("client creation failed: %v", err) + } + + clientFn(c) + + c.Close() // This should trigger client and server shutdown. + ts.Wait() +} + +// TestStartUp tests that the server and client can be started up correctly. +func TestStartUp(t *testing.T) { + runServerClient(t, func(c *lisafs.Client) { + if c.IsSupported(lisafs.Error) { + t.Errorf("sending error messages should not be supported") + } + }) +} + +func TestUnsupportedMessage(t *testing.T) { + unsupportedM := lisafs.MID(len(handlers)) + runServerClient(t, func(c *lisafs.Client) { + if err := c.SndRcvMessage(unsupportedM, 0, lisafs.NoopMarshal, lisafs.NoopUnmarshal, nil); err != unix.EOPNOTSUPP { + t.Errorf("expected EOPNOTSUPP but got err: %v", err) + } + }) +} + +func dynamicMsgHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) { + var req lisafs.MsgDynamic + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + // Just echo back the message. + respPayloadLen := uint32(req.SizeBytes()) + req.MarshalBytes(comm.PayloadBuf(respPayloadLen)) + return respPayloadLen, nil +} + +// TestStress stress tests sending many messages from various goroutines. +func TestStress(t *testing.T) { + runServerClient(t, func(c *lisafs.Client) { + concurrency := 8 + numMsgPerGoroutine := 5000 + var clientWg sync.WaitGroup + for i := 0; i < concurrency; i++ { + clientWg.Add(1) + go func() { + defer clientWg.Done() + + for j := 0; j < numMsgPerGoroutine; j++ { + // Create a massive random message. + var req lisafs.MsgDynamic + req.Randomize(100) + + var resp lisafs.MsgDynamic + if err := c.SndRcvMessage(dynamicMsgID, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil); err != nil { + t.Errorf("SndRcvMessage: received unexpected error %v", err) + return + } + if !reflect.DeepEqual(&req, &resp) { + t.Errorf("response should be the same as request: request = %+v, response = %+v", req, resp) + } + } + }() + } + + clientWg.Wait() + }) +} + +func versionHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) { + // To be fair, usually handlers will create their own objects and return a + // pointer to those. Might be tempting to reuse above variables, but don't. + var rv lisafs.P9Version + rv.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + // Create a new response. + sv := lisafs.P9Version{ + MSize: rv.MSize, + Version: "9P2000.L.Google.11", + } + respPayloadLen := uint32(sv.SizeBytes()) + sv.MarshalBytes(comm.PayloadBuf(respPayloadLen)) + return respPayloadLen, nil +} + +// BenchmarkSendRecv exists to compete against p9's BenchmarkSendRecvChannel. +func BenchmarkSendRecv(b *testing.B) { + b.ReportAllocs() + sendV := lisafs.P9Version{ + MSize: 1 << 20, + Version: "9P2000.L.Google.12", + } + + var recvV lisafs.P9Version + runServerClient(b, func(c *lisafs.Client) { + for i := 0; i < b.N; i++ { + if err := c.SndRcvMessage(versionMsgID, uint32(sendV.SizeBytes()), sendV.MarshalBytes, recvV.UnmarshalBytes, nil); err != nil { + b.Fatalf("unexpected error occurred: %v", err) + } + } + }) +} diff --git a/pkg/lisafs/fd.go b/pkg/lisafs/fd.go new file mode 100644 index 000000000..cc6919a1b --- /dev/null +++ b/pkg/lisafs/fd.go @@ -0,0 +1,374 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/refsvfs2" + "gvisor.dev/gvisor/pkg/sync" +) + +// FDID (file descriptor identifier) is used to identify FDs on a connection. +// Each connection has its own FDID namespace. +// +// +marshal slice:FDIDSlice +type FDID uint32 + +// InvalidFDID represents an invalid FDID. +const InvalidFDID FDID = 0 + +// Ok returns true if f is a valid FDID. +func (f FDID) Ok() bool { + return f != InvalidFDID +} + +// genericFD can represent a ControlFD or OpenFD. +type genericFD interface { + refsvfs2.RefCounter +} + +// A ControlFD is the gateway to the backing filesystem tree node. It is an +// unusual concept. This exists to provide a safe way to do path-based +// operations on the file. It performs operations that can modify the +// filesystem tree and synchronizes these operations. See ControlFDImpl for +// supported operations. +// +// It is not an inode, because multiple control FDs are allowed to exist on the +// same file. It is not a file descriptor because it is not tied to any access +// mode, i.e. a control FD can change its access mode based on the operation +// being performed. +// +// Reference Model: +// * When a control FD is created, the connection takes a ref on it which +// represents the client's ref on the FD. +// * The client can drop its ref via the Close RPC which will in turn make the +// connection drop its ref. +// * Each control FD holds a ref on its parent for its entire life time. +type ControlFD struct { + controlFDRefs + controlFDEntry + + // parent is the parent directory FD containing the file this FD represents. + // A ControlFD holds a ref on parent for its entire lifetime. If this FD + // represents the root, then parent is nil. parent may be a control FD from + // another connection (another mount point). parent is protected by the + // backing server's rename mutex. + parent *ControlFD + + // name is the file path's last component name. If this FD represents the + // root directory, then name is the mount path. name is protected by the + // backing server's rename mutex. + name string + + // children is a linked list of all children control FDs. As per reference + // model, all children hold a ref on this FD. + // children is protected by childrenMu and server's rename mutex. To have + // mutual exclusion, it is sufficient to: + // * Hold rename mutex for reading and lock childrenMu. OR + // * Or hold rename mutex for writing. + childrenMu sync.Mutex + children controlFDList + + // openFDs is a linked list of all FDs opened on this FD. As per reference + // model, all open FDs hold a ref on this FD. + openFDsMu sync.RWMutex + openFDs openFDList + + // All the following fields are immutable. + + // id is the unique FD identifier which identifies this FD on its connection. + id FDID + + // conn is the backing connection owning this FD. + conn *Connection + + // ftype is the file type of the backing inode. ftype.FileType() == ftype. + ftype linux.FileMode + + // impl is the control FD implementation which embeds this struct. It + // contains all the implementation specific details. + impl ControlFDImpl +} + +var _ genericFD = (*ControlFD)(nil) + +// DecRef implements refsvfs2.RefCounter.DecRef. Note that the context +// parameter should never be used. It exists solely to comply with the +// refsvfs2.RefCounter interface. +func (fd *ControlFD) DecRef(context.Context) { + fd.controlFDRefs.DecRef(func() { + if fd.parent != nil { + fd.conn.server.RenameMu.RLock() + fd.parent.childrenMu.Lock() + fd.parent.children.Remove(fd) + fd.parent.childrenMu.Unlock() + fd.conn.server.RenameMu.RUnlock() + fd.parent.DecRef(nil) // Drop the ref on the parent. + } + fd.impl.Close(fd.conn) + }) +} + +// DecRefLocked is the same as DecRef except the added precondition. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) DecRefLocked() { + fd.controlFDRefs.DecRef(func() { + fd.clearParentLocked() + fd.impl.Close(fd.conn) + }) +} + +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) clearParentLocked() { + if fd.parent == nil { + return + } + fd.parent.childrenMu.Lock() + fd.parent.children.Remove(fd) + fd.parent.childrenMu.Unlock() + fd.parent.DecRefLocked() // Drop the ref on the parent. +} + +// Init must be called before first use of fd. It inserts fd into the +// filesystem tree. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) Init(c *Connection, parent *ControlFD, name string, mode linux.FileMode, impl ControlFDImpl) { + // Initialize fd with 1 ref which is transferred to c via c.insertFD(). + fd.controlFDRefs.InitRefs() + fd.conn = c + fd.id = c.insertFD(fd) + fd.name = name + fd.ftype = mode.FileType() + fd.impl = impl + fd.setParentLocked(parent) +} + +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) setParentLocked(parent *ControlFD) { + fd.parent = parent + if parent != nil { + parent.IncRef() // Hold a ref on parent. + parent.childrenMu.Lock() + parent.children.PushBack(fd) + parent.childrenMu.Unlock() + } +} + +// FileType returns the file mode only containing the file type bits. +func (fd *ControlFD) FileType() linux.FileMode { + return fd.ftype +} + +// IsDir indicates whether fd represents a directory. +func (fd *ControlFD) IsDir() bool { + return fd.ftype == unix.S_IFDIR +} + +// IsRegular indicates whether fd represents a regular file. +func (fd *ControlFD) IsRegular() bool { + return fd.ftype == unix.S_IFREG +} + +// IsSymlink indicates whether fd represents a symbolic link. +func (fd *ControlFD) IsSymlink() bool { + return fd.ftype == unix.S_IFLNK +} + +// IsSocket indicates whether fd represents a socket. +func (fd *ControlFD) IsSocket() bool { + return fd.ftype == unix.S_IFSOCK +} + +// NameLocked returns the backing file's last component name. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) NameLocked() string { + return fd.name +} + +// ParentLocked returns the parent control FD. Note that parent might be a +// control FD from another connection on this server. So its ID must not +// returned on this connection because FDIDs are local to their connection. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) ParentLocked() ControlFDImpl { + if fd.parent == nil { + return nil + } + return fd.parent.impl +} + +// ID returns fd's ID. +func (fd *ControlFD) ID() FDID { + return fd.id +} + +// FilePath returns the absolute path of the file fd was opened on. This is +// expensive and must not be called on hot paths. FilePath acquires the rename +// mutex for reading so callers should not be holding it. +func (fd *ControlFD) FilePath() string { + // Lock the rename mutex for reading to ensure that the filesystem tree is not + // changed while we traverse it upwards. + fd.conn.server.RenameMu.RLock() + defer fd.conn.server.RenameMu.RUnlock() + return fd.FilePathLocked() +} + +// FilePathLocked is the same as FilePath with the additional precondition. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) FilePathLocked() string { + // Walk upwards and prepend name to res. + var res fspath.Builder + for fd != nil { + res.PrependComponent(fd.name) + fd = fd.parent + } + return res.String() +} + +// ForEachOpenFD executes fn on each FD opened on fd. +func (fd *ControlFD) ForEachOpenFD(fn func(ofd OpenFDImpl)) { + fd.openFDsMu.RLock() + defer fd.openFDsMu.RUnlock() + for ofd := fd.openFDs.Front(); ofd != nil; ofd = ofd.Next() { + fn(ofd.impl) + } +} + +// OpenFD represents an open file descriptor on the protocol. It resonates +// closely with a Linux file descriptor. Its operations are limited to the +// file. Its operations are not allowed to modify or traverse the filesystem +// tree. See OpenFDImpl for the supported operations. +// +// Reference Model: +// * An OpenFD takes a reference on the control FD it was opened on. +type OpenFD struct { + openFDRefs + openFDEntry + + // All the following fields are immutable. + + // controlFD is the ControlFD on which this FD was opened. OpenFD holds a ref + // on controlFD for its entire lifetime. + controlFD *ControlFD + + // id is the unique FD identifier which identifies this FD on its connection. + id FDID + + // Access mode for this FD. + readable bool + writable bool + + // impl is the open FD implementation which embeds this struct. It + // contains all the implementation specific details. + impl OpenFDImpl +} + +var _ genericFD = (*OpenFD)(nil) + +// ID returns fd's ID. +func (fd *OpenFD) ID() FDID { + return fd.id +} + +// ControlFD returns the control FD on which this FD was opened. +func (fd *OpenFD) ControlFD() ControlFDImpl { + return fd.controlFD.impl +} + +// DecRef implements refsvfs2.RefCounter.DecRef. Note that the context +// parameter should never be used. It exists solely to comply with the +// refsvfs2.RefCounter interface. +func (fd *OpenFD) DecRef(context.Context) { + fd.openFDRefs.DecRef(func() { + fd.controlFD.openFDsMu.Lock() + fd.controlFD.openFDs.Remove(fd) + fd.controlFD.openFDsMu.Unlock() + fd.controlFD.DecRef(nil) // Drop the ref on the control FD. + fd.impl.Close(fd.controlFD.conn) + }) +} + +// Init must be called before first use of fd. +func (fd *OpenFD) Init(cfd *ControlFD, flags uint32, impl OpenFDImpl) { + // Initialize fd with 1 ref which is transferred to c via c.insertFD(). + fd.openFDRefs.InitRefs() + fd.controlFD = cfd + fd.id = cfd.conn.insertFD(fd) + accessMode := flags & unix.O_ACCMODE + fd.readable = accessMode == unix.O_RDONLY || accessMode == unix.O_RDWR + fd.writable = accessMode == unix.O_WRONLY || accessMode == unix.O_RDWR + fd.impl = impl + cfd.IncRef() // Holds a ref on cfd for its lifetime. + cfd.openFDsMu.Lock() + cfd.openFDs.PushBack(fd) + cfd.openFDsMu.Unlock() +} + +// ControlFDImpl contains implementation details for a ControlFD. +// Implementations of ControlFDImpl should contain their associated ControlFD +// by value as their first field. +// +// The operations that perform path traversal or any modification to the +// filesystem tree must synchronize those modifications with the server's +// rename mutex. +type ControlFDImpl interface { + FD() *ControlFD + Close(c *Connection) + Stat(c *Connection, comm Communicator) (uint32, error) + SetStat(c *Connection, comm Communicator, stat SetStatReq) (uint32, error) + Walk(c *Connection, comm Communicator, path StringArray) (uint32, error) + WalkStat(c *Connection, comm Communicator, path StringArray) (uint32, error) + Open(c *Connection, comm Communicator, flags uint32) (uint32, error) + OpenCreate(c *Connection, comm Communicator, mode linux.FileMode, uid UID, gid GID, name string, flags uint32) (uint32, error) + Mkdir(c *Connection, comm Communicator, mode linux.FileMode, uid UID, gid GID, name string) (uint32, error) + Mknod(c *Connection, comm Communicator, mode linux.FileMode, uid UID, gid GID, name string, minor uint32, major uint32) (uint32, error) + Symlink(c *Connection, comm Communicator, name string, target string, uid UID, gid GID) (uint32, error) + Link(c *Connection, comm Communicator, dir ControlFDImpl, name string) (uint32, error) + StatFS(c *Connection, comm Communicator) (uint32, error) + Readlink(c *Connection, comm Communicator) (uint32, error) + Connect(c *Connection, comm Communicator, sockType uint32) error + Unlink(c *Connection, name string, flags uint32) error + RenameLocked(c *Connection, newDir ControlFDImpl, newName string) (func(ControlFDImpl), func(), error) + GetXattr(c *Connection, comm Communicator, name string, size uint32) (uint32, error) + SetXattr(c *Connection, name string, value string, flags uint32) error + ListXattr(c *Connection, comm Communicator, size uint64) (uint32, error) + RemoveXattr(c *Connection, comm Communicator, name string) error +} + +// OpenFDImpl contains implementation details for a OpenFD. Implementations of +// OpenFDImpl should contain their associated OpenFD by value as their first +// field. +// +// Since these operations do not perform any path traversal or any modification +// to the filesystem tree, there is no need to synchronize with rename +// operations. +type OpenFDImpl interface { + FD() *OpenFD + Close(c *Connection) + Stat(c *Connection, comm Communicator) (uint32, error) + Sync(c *Connection) error + Write(c *Connection, comm Communicator, buf []byte, off uint64) (uint32, error) + Read(c *Connection, comm Communicator, off uint64, count uint32) (uint32, error) + Allocate(c *Connection, mode, off, length uint64) error + Flush(c *Connection) error + Getdent64(c *Connection, comm Communicator, count uint32, seek0 bool) (uint32, error) +} diff --git a/pkg/lisafs/handlers.go b/pkg/lisafs/handlers.go new file mode 100644 index 000000000..82807734d --- /dev/null +++ b/pkg/lisafs/handlers.go @@ -0,0 +1,768 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "fmt" + "path" + "path/filepath" + "strings" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal/primitive" +) + +const ( + allowedOpenFlags = unix.O_ACCMODE | unix.O_TRUNC + setStatSupportedMask = unix.STATX_MODE | unix.STATX_UID | unix.STATX_GID | unix.STATX_SIZE | unix.STATX_ATIME | unix.STATX_MTIME +) + +// RPCHandler defines a handler that is invoked when the associated message is +// received. The handler is responsible for: +// +// * Unmarshalling the request from the passed payload and interpreting it. +// * Marshalling the response into the communicator's payload buffer. +// * Return the number of payload bytes written. +// * Donate any FDs (if needed) to comm which will in turn donate it to client. +type RPCHandler func(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) + +var handlers = [...]RPCHandler{ + Error: ErrorHandler, + Mount: MountHandler, + Channel: ChannelHandler, + FStat: FStatHandler, + SetStat: SetStatHandler, + Walk: WalkHandler, + WalkStat: WalkStatHandler, + OpenAt: OpenAtHandler, + OpenCreateAt: OpenCreateAtHandler, + Close: CloseHandler, + FSync: FSyncHandler, + PWrite: PWriteHandler, + PRead: PReadHandler, + MkdirAt: MkdirAtHandler, + MknodAt: MknodAtHandler, + SymlinkAt: SymlinkAtHandler, + LinkAt: LinkAtHandler, + FStatFS: FStatFSHandler, + FAllocate: FAllocateHandler, + ReadLinkAt: ReadLinkAtHandler, + Flush: FlushHandler, + Connect: ConnectHandler, + UnlinkAt: UnlinkAtHandler, + RenameAt: RenameAtHandler, + Getdents64: Getdents64Handler, + FGetXattr: FGetXattrHandler, + FSetXattr: FSetXattrHandler, + FListXattr: FListXattrHandler, + FRemoveXattr: FRemoveXattrHandler, +} + +// ErrorHandler handles Error message. +func ErrorHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + // Client should never send Error. + return 0, unix.EINVAL +} + +// MountHandler handles the Mount RPC. Note that there can not be concurrent +// executions of MountHandler on a connection because the connection enforces +// that Mount is the first message on the connection. Only after the connection +// has been successfully mounted can other channels be created. +func MountHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req MountReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + mountPath := path.Clean(string(req.MountPath)) + if !filepath.IsAbs(mountPath) { + log.Warningf("mountPath %q is not absolute", mountPath) + return 0, unix.EINVAL + } + + if c.mounted { + log.Warningf("connection has already been mounted at %q", mountPath) + return 0, unix.EBUSY + } + + rootFD, rootIno, err := c.ServerImpl().Mount(c, mountPath) + if err != nil { + return 0, err + } + + c.server.addMountPoint(rootFD.FD()) + c.mounted = true + resp := MountResp{ + Root: rootIno, + SupportedMs: c.ServerImpl().SupportedMessages(), + MaxMessageSize: primitive.Uint32(c.ServerImpl().MaxMessageSize()), + } + respPayloadLen := uint32(resp.SizeBytes()) + resp.MarshalBytes(comm.PayloadBuf(respPayloadLen)) + return respPayloadLen, nil +} + +// ChannelHandler handles the Channel RPC. +func ChannelHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + ch, desc, fdSock, err := c.createChannel(c.ServerImpl().MaxMessageSize()) + if err != nil { + return 0, err + } + + // Start servicing the channel in a separate goroutine. + c.activeWg.Add(1) + go func() { + if err := c.service(ch); err != nil { + // Don't log shutdown error which is expected during server shutdown. + if _, ok := err.(flipcall.ShutdownError); !ok { + log.Warningf("lisafs.Connection.service(channel = @%p): %v", ch, err) + } + } + c.activeWg.Done() + }() + + clientDataFD, err := unix.Dup(desc.FD) + if err != nil { + unix.Close(fdSock) + ch.shutdown() + return 0, err + } + + // Respond to client with successful channel creation message. + if err := comm.DonateFD(clientDataFD); err != nil { + return 0, err + } + if err := comm.DonateFD(fdSock); err != nil { + return 0, err + } + resp := ChannelResp{ + dataOffset: desc.Offset, + dataLength: uint64(desc.Length), + } + respLen := uint32(resp.SizeBytes()) + resp.MarshalUnsafe(comm.PayloadBuf(respLen)) + return respLen, nil +} + +// FStatHandler handles the FStat RPC. +func FStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req StatReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.lookupFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + + switch t := fd.(type) { + case *ControlFD: + return t.impl.Stat(c, comm) + case *OpenFD: + return t.impl.Stat(c, comm) + default: + panic(fmt.Sprintf("unknown fd type %T", t)) + } +} + +// SetStatHandler handles the SetStat RPC. +func SetStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + + var req SetStatReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + + if req.Mask&^setStatSupportedMask != 0 { + return 0, unix.EPERM + } + + return fd.impl.SetStat(c, comm, req) +} + +// WalkHandler handles the Walk RPC. +func WalkHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req WalkReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + for _, name := range req.Path { + if err := checkSafeName(name); err != nil { + return 0, err + } + } + + return fd.impl.Walk(c, comm, req.Path) +} + +// WalkStatHandler handles the WalkStat RPC. +func WalkStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req WalkReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + + // Note that this fd is allowed to not actually be a directory when the + // only path component to walk is "" (self). + if !fd.IsDir() { + if len(req.Path) > 1 || (len(req.Path) == 1 && len(req.Path[0]) > 0) { + return 0, unix.ENOTDIR + } + } + for i, name := range req.Path { + // First component is allowed to be "". + if i == 0 && len(name) == 0 { + continue + } + if err := checkSafeName(name); err != nil { + return 0, err + } + } + + return fd.impl.WalkStat(c, comm, req.Path) +} + +// OpenAtHandler handles the OpenAt RPC. +func OpenAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req OpenAtReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + // Only keep allowed open flags. + if allowedFlags := req.Flags & allowedOpenFlags; allowedFlags != req.Flags { + log.Debugf("discarding open flags that are not allowed: old open flags = %d, new open flags = %d", req.Flags, allowedFlags) + req.Flags = allowedFlags + } + + accessMode := req.Flags & unix.O_ACCMODE + trunc := req.Flags&unix.O_TRUNC != 0 + if c.readonly && (accessMode != unix.O_RDONLY || trunc) { + return 0, unix.EROFS + } + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if fd.IsDir() { + // Directory is not truncatable and must be opened with O_RDONLY. + if accessMode != unix.O_RDONLY || trunc { + return 0, unix.EISDIR + } + } + + return fd.impl.Open(c, comm, req.Flags) +} + +// OpenCreateAtHandler handles the OpenCreateAt RPC. +func OpenCreateAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req OpenCreateAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + // Only keep allowed open flags. + if allowedFlags := req.Flags & allowedOpenFlags; allowedFlags != req.Flags { + log.Debugf("discarding open flags that are not allowed: old open flags = %d, new open flags = %d", req.Flags, allowedFlags) + req.Flags = allowedFlags + } + + name := string(req.Name) + if err := checkSafeName(name); err != nil { + return 0, err + } + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + + return fd.impl.OpenCreate(c, comm, req.Mode, req.UID, req.GID, name, uint32(req.Flags)) +} + +// CloseHandler handles the Close RPC. +func CloseHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req CloseReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + for _, fd := range req.FDs { + c.RemoveFD(fd) + } + + // There is no response message for this. + return 0, nil +} + +// FSyncHandler handles the FSync RPC. +func FSyncHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req FsyncReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + // Return the first error we encounter, but sync everything we can + // regardless. + var retErr error + for _, fdid := range req.FDs { + if err := c.fsyncFD(fdid); err != nil && retErr == nil { + retErr = err + } + } + + // There is no response message for this. + return 0, retErr +} + +func (c *Connection) fsyncFD(id FDID) error { + fd, err := c.LookupOpenFD(id) + if err != nil { + return err + } + return fd.impl.Sync(c) +} + +// PWriteHandler handles the PWrite RPC. +func PWriteHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req PWriteReq + // Note that it is an optimized Unmarshal operation which avoids any buffer + // allocation and copying. req.Buf just points to payload. This is safe to do + // as the handler owns payload and req's lifetime is limited to the handler. + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + fd, err := c.LookupOpenFD(req.FD) + if err != nil { + return 0, err + } + if !fd.writable { + return 0, unix.EBADF + } + return fd.impl.Write(c, comm, req.Buf, uint64(req.Offset)) +} + +// PReadHandler handles the PRead RPC. +func PReadHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req PReadReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupOpenFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.readable { + return 0, unix.EBADF + } + return fd.impl.Read(c, comm, req.Offset, req.Count) +} + +// MkdirAtHandler handles the MkdirAt RPC. +func MkdirAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req MkdirAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + name := string(req.Name) + if err := checkSafeName(name); err != nil { + return 0, err + } + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + return fd.impl.Mkdir(c, comm, req.Mode, req.UID, req.GID, name) +} + +// MknodAtHandler handles the MknodAt RPC. +func MknodAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req MknodAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + name := string(req.Name) + if err := checkSafeName(name); err != nil { + return 0, err + } + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + return fd.impl.Mknod(c, comm, req.Mode, req.UID, req.GID, name, uint32(req.Minor), uint32(req.Major)) +} + +// SymlinkAtHandler handles the SymlinkAt RPC. +func SymlinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req SymlinkAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + name := string(req.Name) + if err := checkSafeName(name); err != nil { + return 0, err + } + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + return fd.impl.Symlink(c, comm, name, string(req.Target), req.UID, req.GID) +} + +// LinkAtHandler handles the LinkAt RPC. +func LinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req LinkAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + name := string(req.Name) + if err := checkSafeName(name); err != nil { + return 0, err + } + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + + targetFD, err := c.LookupControlFD(req.Target) + if err != nil { + return 0, err + } + return targetFD.impl.Link(c, comm, fd.impl, name) +} + +// FStatFSHandler handles the FStatFS RPC. +func FStatFSHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req FStatFSReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + return fd.impl.StatFS(c, comm) +} + +// FAllocateHandler handles the FAllocate RPC. +func FAllocateHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req FAllocateReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupOpenFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.writable { + return 0, unix.EBADF + } + return 0, fd.impl.Allocate(c, req.Mode, req.Offset, req.Length) +} + +// ReadLinkAtHandler handles the ReadLinkAt RPC. +func ReadLinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req ReadLinkAtReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsSymlink() { + return 0, unix.EINVAL + } + return fd.impl.Readlink(c, comm) +} + +// FlushHandler handles the Flush RPC. +func FlushHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req FlushReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupOpenFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + + return 0, fd.impl.Flush(c) +} + +// ConnectHandler handles the Connect RPC. +func ConnectHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req ConnectReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsSocket() { + return 0, unix.ENOTSOCK + } + return 0, fd.impl.Connect(c, comm, req.SockType) +} + +// UnlinkAtHandler handles the UnlinkAt RPC. +func UnlinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req UnlinkAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + name := string(req.Name) + if err := checkSafeName(name); err != nil { + return 0, err + } + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + return 0, fd.impl.Unlink(c, name, uint32(req.Flags)) +} + +// RenameAtHandler handles the RenameAt RPC. +func RenameAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req RenameAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + newName := string(req.NewName) + if err := checkSafeName(newName); err != nil { + return 0, err + } + + renamed, err := c.LookupControlFD(req.Renamed) + if err != nil { + return 0, err + } + defer renamed.DecRef(nil) + + newDir, err := c.LookupControlFD(req.NewDir) + if err != nil { + return 0, err + } + defer newDir.DecRef(nil) + if !newDir.IsDir() { + return 0, unix.ENOTDIR + } + + // Hold RenameMu for writing during rename, this is important. + c.server.RenameMu.Lock() + defer c.server.RenameMu.Unlock() + + if renamed.parent == nil { + // renamed is root. + return 0, unix.EBUSY + } + + oldParentPath := renamed.parent.FilePathLocked() + oldPath := oldParentPath + "/" + renamed.name + if newName == renamed.name && oldParentPath == newDir.FilePathLocked() { + // Nothing to do. + return 0, nil + } + + updateControlFD, cleanUp, err := renamed.impl.RenameLocked(c, newDir.impl, newName) + if err != nil { + return 0, err + } + + c.server.forEachMountPoint(func(root *ControlFD) { + if !strings.HasPrefix(oldPath, root.name) { + return + } + pit := fspath.Parse(oldPath[len(root.name):]).Begin + root.renameRecursiveLocked(newDir, newName, pit, updateControlFD) + }) + + if cleanUp != nil { + cleanUp() + } + return 0, nil +} + +// Precondition: rename mutex must be locked for writing. +func (fd *ControlFD) renameRecursiveLocked(newDir *ControlFD, newName string, pit fspath.Iterator, updateControlFD func(ControlFDImpl)) { + if !pit.Ok() { + // fd should be renamed. + fd.clearParentLocked() + fd.setParentLocked(newDir) + fd.name = newName + if updateControlFD != nil { + updateControlFD(fd.impl) + } + return + } + + cur := pit.String() + next := pit.Next() + // No need to hold fd.childrenMu because RenameMu is locked for writing. + for child := fd.children.Front(); child != nil; child = child.Next() { + if child.name == cur { + child.renameRecursiveLocked(newDir, newName, next, updateControlFD) + } + } +} + +// Getdents64Handler handles the Getdents64 RPC. +func Getdents64Handler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req Getdents64Req + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupOpenFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.controlFD.IsDir() { + return 0, unix.ENOTDIR + } + + seek0 := false + if req.Count < 0 { + seek0 = true + req.Count = -req.Count + } + return fd.impl.Getdent64(c, comm, uint32(req.Count), seek0) +} + +// FGetXattrHandler handles the FGetXattr RPC. +func FGetXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req FGetXattrReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + return fd.impl.GetXattr(c, comm, string(req.Name), uint32(req.BufSize)) +} + +// FSetXattrHandler handles the FSetXattr RPC. +func FSetXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req FSetXattrReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + return 0, fd.impl.SetXattr(c, string(req.Name), string(req.Value), uint32(req.Flags)) +} + +// FListXattrHandler handles the FListXattr RPC. +func FListXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req FListXattrReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + return fd.impl.ListXattr(c, comm, req.Size) +} + +// FRemoveXattrHandler handles the FRemoveXattr RPC. +func FRemoveXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req FRemoveXattrReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + return 0, fd.impl.RemoveXattr(c, comm, string(req.Name)) +} + +// checkSafeName validates the name and returns nil or returns an error. +func checkSafeName(name string) error { + if name != "" && !strings.Contains(name, "/") && name != "." && name != ".." { + return nil + } + return unix.EINVAL +} diff --git a/pkg/lisafs/lisafs.go b/pkg/lisafs/lisafs.go new file mode 100644 index 000000000..4d8e956ab --- /dev/null +++ b/pkg/lisafs/lisafs.go @@ -0,0 +1,18 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package lisafs (LInux SAndbox FileSystem) defines the protocol for +// filesystem RPCs between an untrusted Sandbox (client) and a trusted +// filesystem server. +package lisafs diff --git a/pkg/lisafs/message.go b/pkg/lisafs/message.go new file mode 100644 index 000000000..722afd0be --- /dev/null +++ b/pkg/lisafs/message.go @@ -0,0 +1,1251 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "math" + "os" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal/primitive" +) + +// Messages have two parts: +// * A transport header used to decipher received messages. +// * A byte array referred to as "payload" which contains the actual message. +// +// "dataLen" refers to the size of both combined. + +// MID (message ID) is used to identify messages to parse from payload. +// +// +marshal slice:MIDSlice +type MID uint16 + +// These constants are used to identify their corresponding message types. +const ( + // Error is only used in responses to pass errors to client. + Error MID = 0 + + // Mount is used to establish connection between the client and server mount + // point. lisafs requires that the client makes a successful Mount RPC before + // making other RPCs. + Mount MID = 1 + + // Channel requests to start a new communicational channel. + Channel MID = 2 + + // FStat requests the stat(2) results for a specified file. + FStat MID = 3 + + // SetStat requests to change file attributes. Note that there is no one + // corresponding Linux syscall. This is a conglomeration of fchmod(2), + // fchown(2), ftruncate(2) and futimesat(2). + SetStat MID = 4 + + // Walk requests to walk the specified path starting from the specified + // directory. Server-side path traversal is terminated preemptively on + // symlinks entries because they can cause non-linear traversal. + Walk MID = 5 + + // WalkStat is the same as Walk, except the following differences: + // * If the first path component is "", then it also returns stat results + // for the directory where the walk starts. + // * Does not return Inode, just the Stat results for each path component. + WalkStat MID = 6 + + // OpenAt is analogous to openat(2). It does not perform any walk. It merely + // duplicates the control FD with the open flags passed. + OpenAt MID = 7 + + // OpenCreateAt is analogous to openat(2) with O_CREAT|O_EXCL added to flags. + // It also returns the newly created file inode. + OpenCreateAt MID = 8 + + // Close is analogous to close(2) but can work on multiple FDs. + Close MID = 9 + + // FSync is analogous to fsync(2) but can work on multiple FDs. + FSync MID = 10 + + // PWrite is analogous to pwrite(2). + PWrite MID = 11 + + // PRead is analogous to pread(2). + PRead MID = 12 + + // MkdirAt is analogous to mkdirat(2). + MkdirAt MID = 13 + + // MknodAt is analogous to mknodat(2). + MknodAt MID = 14 + + // SymlinkAt is analogous to symlinkat(2). + SymlinkAt MID = 15 + + // LinkAt is analogous to linkat(2). + LinkAt MID = 16 + + // FStatFS is analogous to fstatfs(2). + FStatFS MID = 17 + + // FAllocate is analogous to fallocate(2). + FAllocate MID = 18 + + // ReadLinkAt is analogous to readlinkat(2). + ReadLinkAt MID = 19 + + // Flush cleans up the file state. Its behavior is implementation + // dependent and might not even be supported in server implementations. + Flush MID = 20 + + // Connect is loosely analogous to connect(2). + Connect MID = 21 + + // UnlinkAt is analogous to unlinkat(2). + UnlinkAt MID = 22 + + // RenameAt is loosely analogous to renameat(2). + RenameAt MID = 23 + + // Getdents64 is analogous to getdents64(2). + Getdents64 MID = 24 + + // FGetXattr is analogous to fgetxattr(2). + FGetXattr MID = 25 + + // FSetXattr is analogous to fsetxattr(2). + FSetXattr MID = 26 + + // FListXattr is analogous to flistxattr(2). + FListXattr MID = 27 + + // FRemoveXattr is analogous to fremovexattr(2). + FRemoveXattr MID = 28 +) + +const ( + // NoUID is a sentinel used to indicate no valid UID. + NoUID UID = math.MaxUint32 + + // NoGID is a sentinel used to indicate no valid GID. + NoGID GID = math.MaxUint32 +) + +// MaxMessageSize is the recommended max message size that can be used by +// connections. Server implementations may choose to use other values. +func MaxMessageSize() uint32 { + // Return HugePageSize - PageSize so that when flipcall packet window is + // created with MaxMessageSize() + flipcall header size + channel header + // size, HugePageSize is allocated and can be backed by a single huge page + // if supported by the underlying memfd. + return uint32(hostarch.HugePageSize - os.Getpagesize()) +} + +// TODO(gvisor.dev/issue/6450): Once this is resolved: +// * Update manual implementations and function signatures. +// * Update RPC handlers and appropriate callers to handle errors correctly. +// * Update manual implementations to get rid of buffer shifting. + +// UID represents a user ID. +// +// +marshal +type UID uint32 + +// Ok returns true if uid is not NoUID. +func (uid UID) Ok() bool { + return uid != NoUID +} + +// GID represents a group ID. +// +// +marshal +type GID uint32 + +// Ok returns true if gid is not NoGID. +func (gid GID) Ok() bool { + return gid != NoGID +} + +// NoopMarshal is a noop implementation of marshal.Marshallable.MarshalBytes. +func NoopMarshal([]byte) {} + +// NoopUnmarshal is a noop implementation of marshal.Marshallable.UnmarshalBytes. +func NoopUnmarshal([]byte) {} + +// SizedString represents a string in memory. The marshalled string bytes are +// preceded by a uint32 signifying the string length. +type SizedString string + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SizedString) SizeBytes() int { + return (*primitive.Uint32)(nil).SizeBytes() + len(*s) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SizedString) MarshalBytes(dst []byte) { + strLen := primitive.Uint32(len(*s)) + strLen.MarshalUnsafe(dst) + dst = dst[strLen.SizeBytes():] + // Copy without any allocation. + copy(dst[:strLen], *s) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SizedString) UnmarshalBytes(src []byte) { + var strLen primitive.Uint32 + strLen.UnmarshalUnsafe(src) + src = src[strLen.SizeBytes():] + // Take the hit, this leads to an allocation + memcpy. No way around it. + *s = SizedString(src[:strLen]) +} + +// StringArray represents an array of SizedStrings in memory. The marshalled +// array data is preceded by a uint32 signifying the array length. +type StringArray []string + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *StringArray) SizeBytes() int { + size := (*primitive.Uint32)(nil).SizeBytes() + for _, str := range *s { + sstr := SizedString(str) + size += sstr.SizeBytes() + } + return size +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *StringArray) MarshalBytes(dst []byte) { + arrLen := primitive.Uint32(len(*s)) + arrLen.MarshalUnsafe(dst) + dst = dst[arrLen.SizeBytes():] + for _, str := range *s { + sstr := SizedString(str) + sstr.MarshalBytes(dst) + dst = dst[sstr.SizeBytes():] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *StringArray) UnmarshalBytes(src []byte) { + var arrLen primitive.Uint32 + arrLen.UnmarshalUnsafe(src) + src = src[arrLen.SizeBytes():] + + if cap(*s) < int(arrLen) { + *s = make([]string, arrLen) + } else { + *s = (*s)[:arrLen] + } + + for i := primitive.Uint32(0); i < arrLen; i++ { + var sstr SizedString + sstr.UnmarshalBytes(src) + src = src[sstr.SizeBytes():] + (*s)[i] = string(sstr) + } +} + +// Inode represents an inode on the remote filesystem. +// +// +marshal slice:InodeSlice +type Inode struct { + ControlFD FDID + _ uint32 // Need to make struct packed. + Stat linux.Statx +} + +// MountReq represents a Mount request. +type MountReq struct { + MountPath SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MountReq) SizeBytes() int { + return m.MountPath.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MountReq) MarshalBytes(dst []byte) { + m.MountPath.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MountReq) UnmarshalBytes(src []byte) { + m.MountPath.UnmarshalBytes(src) +} + +// MountResp represents a Mount response. +type MountResp struct { + Root Inode + // MaxMessageSize is the maximum size of messages communicated between the + // client and server in bytes. This includes the communication header. + MaxMessageSize primitive.Uint32 + // SupportedMs holds all the supported messages. + SupportedMs []MID +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MountResp) SizeBytes() int { + return m.Root.SizeBytes() + + m.MaxMessageSize.SizeBytes() + + (*primitive.Uint16)(nil).SizeBytes() + + (len(m.SupportedMs) * (*MID)(nil).SizeBytes()) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MountResp) MarshalBytes(dst []byte) { + m.Root.MarshalUnsafe(dst) + dst = dst[m.Root.SizeBytes():] + m.MaxMessageSize.MarshalUnsafe(dst) + dst = dst[m.MaxMessageSize.SizeBytes():] + numSupported := primitive.Uint16(len(m.SupportedMs)) + numSupported.MarshalBytes(dst) + dst = dst[numSupported.SizeBytes():] + MarshalUnsafeMIDSlice(m.SupportedMs, dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MountResp) UnmarshalBytes(src []byte) { + m.Root.UnmarshalUnsafe(src) + src = src[m.Root.SizeBytes():] + m.MaxMessageSize.UnmarshalUnsafe(src) + src = src[m.MaxMessageSize.SizeBytes():] + var numSupported primitive.Uint16 + numSupported.UnmarshalBytes(src) + src = src[numSupported.SizeBytes():] + m.SupportedMs = make([]MID, numSupported) + UnmarshalUnsafeMIDSlice(m.SupportedMs, src) +} + +// ChannelResp is the response to the create channel request. +// +// +marshal +type ChannelResp struct { + dataOffset int64 + dataLength uint64 +} + +// ErrorResp is returned to represent an error while handling a request. +// +// +marshal +type ErrorResp struct { + errno uint32 +} + +// StatReq requests the stat results for the specified FD. +// +// +marshal +type StatReq struct { + FD FDID +} + +// SetStatReq is used to set attributeds on FDs. +// +// +marshal +type SetStatReq struct { + FD FDID + _ uint32 + Mask uint32 + Mode uint32 // Only permissions part is settable. + UID UID + GID GID + Size uint64 + Atime linux.Timespec + Mtime linux.Timespec +} + +// SetStatResp is used to communicate SetStat results. It contains a mask +// representing the failed changes. It also contains the errno of the failed +// set attribute operation. If multiple operations failed then any of those +// errnos can be returned. +// +// +marshal +type SetStatResp struct { + FailureMask uint32 + FailureErrNo uint32 +} + +// WalkReq is used to request to walk multiple path components at once. This +// is used for both Walk and WalkStat. +type WalkReq struct { + DirFD FDID + Path StringArray +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (w *WalkReq) SizeBytes() int { + return w.DirFD.SizeBytes() + w.Path.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (w *WalkReq) MarshalBytes(dst []byte) { + w.DirFD.MarshalUnsafe(dst) + dst = dst[w.DirFD.SizeBytes():] + w.Path.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (w *WalkReq) UnmarshalBytes(src []byte) { + w.DirFD.UnmarshalUnsafe(src) + src = src[w.DirFD.SizeBytes():] + w.Path.UnmarshalBytes(src) +} + +// WalkStatus is used to indicate the reason for partial/unsuccessful server +// side Walk operations. Please note that partial/unsuccessful walk operations +// do not necessarily fail the RPC. The RPC is successful with a failure hint +// which can be used by the client to infer server-side state. +type WalkStatus = primitive.Uint8 + +const ( + // WalkSuccess indicates that all path components were successfully walked. + WalkSuccess WalkStatus = iota + + // WalkComponentDoesNotExist indicates that the walk was prematurely + // terminated because an intermediate path component does not exist on + // server. The results of all previous existing path components is returned. + WalkComponentDoesNotExist + + // WalkComponentSymlink indicates that the walk was prematurely + // terminated because an intermediate path component was a symlink. It is not + // safe to resolve symlinks remotely (unaware of mount points). + WalkComponentSymlink +) + +// WalkResp is used to communicate the inodes walked by the server. +type WalkResp struct { + Status WalkStatus + Inodes []Inode +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (w *WalkResp) SizeBytes() int { + return w.Status.SizeBytes() + + (*primitive.Uint32)(nil).SizeBytes() + (len(w.Inodes) * (*Inode)(nil).SizeBytes()) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (w *WalkResp) MarshalBytes(dst []byte) { + w.Status.MarshalUnsafe(dst) + dst = dst[w.Status.SizeBytes():] + + numInodes := primitive.Uint32(len(w.Inodes)) + numInodes.MarshalUnsafe(dst) + dst = dst[numInodes.SizeBytes():] + + MarshalUnsafeInodeSlice(w.Inodes, dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (w *WalkResp) UnmarshalBytes(src []byte) { + w.Status.UnmarshalUnsafe(src) + src = src[w.Status.SizeBytes():] + + var numInodes primitive.Uint32 + numInodes.UnmarshalUnsafe(src) + src = src[numInodes.SizeBytes():] + + if cap(w.Inodes) < int(numInodes) { + w.Inodes = make([]Inode, numInodes) + } else { + w.Inodes = w.Inodes[:numInodes] + } + UnmarshalUnsafeInodeSlice(w.Inodes, src) +} + +// WalkStatResp is used to communicate stat results for WalkStat. +type WalkStatResp struct { + Stats []linux.Statx +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (w *WalkStatResp) SizeBytes() int { + return (*primitive.Uint32)(nil).SizeBytes() + (len(w.Stats) * linux.SizeOfStatx) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (w *WalkStatResp) MarshalBytes(dst []byte) { + numStats := primitive.Uint32(len(w.Stats)) + numStats.MarshalUnsafe(dst) + dst = dst[numStats.SizeBytes():] + + linux.MarshalUnsafeStatxSlice(w.Stats, dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (w *WalkStatResp) UnmarshalBytes(src []byte) { + var numStats primitive.Uint32 + numStats.UnmarshalUnsafe(src) + src = src[numStats.SizeBytes():] + + if cap(w.Stats) < int(numStats) { + w.Stats = make([]linux.Statx, numStats) + } else { + w.Stats = w.Stats[:numStats] + } + linux.UnmarshalUnsafeStatxSlice(w.Stats, src) +} + +// OpenAtReq is used to open existing FDs with the specified flags. +// +// +marshal +type OpenAtReq struct { + FD FDID + Flags uint32 +} + +// OpenAtResp is used to communicate the newly created FD. +// +// +marshal +type OpenAtResp struct { + NewFD FDID +} + +// +marshal +type createCommon struct { + DirFD FDID + Mode linux.FileMode + _ uint16 // Need to make struct packed. + UID UID + GID GID +} + +// OpenCreateAtReq is used to make OpenCreateAt requests. +type OpenCreateAtReq struct { + createCommon + Name SizedString + Flags primitive.Uint32 +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (o *OpenCreateAtReq) SizeBytes() int { + return o.createCommon.SizeBytes() + o.Name.SizeBytes() + o.Flags.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (o *OpenCreateAtReq) MarshalBytes(dst []byte) { + o.createCommon.MarshalUnsafe(dst) + dst = dst[o.createCommon.SizeBytes():] + o.Name.MarshalBytes(dst) + dst = dst[o.Name.SizeBytes():] + o.Flags.MarshalUnsafe(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (o *OpenCreateAtReq) UnmarshalBytes(src []byte) { + o.createCommon.UnmarshalUnsafe(src) + src = src[o.createCommon.SizeBytes():] + o.Name.UnmarshalBytes(src) + src = src[o.Name.SizeBytes():] + o.Flags.UnmarshalUnsafe(src) +} + +// OpenCreateAtResp is used to communicate successful OpenCreateAt results. +// +// +marshal +type OpenCreateAtResp struct { + Child Inode + NewFD FDID + _ uint32 // Need to make struct packed. +} + +// FdArray is a utility struct which implements a marshallable type for +// communicating an array of FDIDs. In memory, the array data is preceded by a +// uint32 denoting the array length. +type FdArray []FDID + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FdArray) SizeBytes() int { + return (*primitive.Uint32)(nil).SizeBytes() + (len(*f) * (*FDID)(nil).SizeBytes()) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FdArray) MarshalBytes(dst []byte) { + arrLen := primitive.Uint32(len(*f)) + arrLen.MarshalUnsafe(dst) + dst = dst[arrLen.SizeBytes():] + MarshalUnsafeFDIDSlice(*f, dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FdArray) UnmarshalBytes(src []byte) { + var arrLen primitive.Uint32 + arrLen.UnmarshalUnsafe(src) + src = src[arrLen.SizeBytes():] + if cap(*f) < int(arrLen) { + *f = make(FdArray, arrLen) + } else { + *f = (*f)[:arrLen] + } + UnmarshalUnsafeFDIDSlice(*f, src) +} + +// CloseReq is used to close(2) FDs. +type CloseReq struct { + FDs FdArray +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (c *CloseReq) SizeBytes() int { + return c.FDs.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (c *CloseReq) MarshalBytes(dst []byte) { + c.FDs.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (c *CloseReq) UnmarshalBytes(src []byte) { + c.FDs.UnmarshalBytes(src) +} + +// FsyncReq is used to fsync(2) FDs. +type FsyncReq struct { + FDs FdArray +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FsyncReq) SizeBytes() int { + return f.FDs.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FsyncReq) MarshalBytes(dst []byte) { + f.FDs.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FsyncReq) UnmarshalBytes(src []byte) { + f.FDs.UnmarshalBytes(src) +} + +// PReadReq is used to pread(2) on an FD. +// +// +marshal +type PReadReq struct { + Offset uint64 + FD FDID + Count uint32 +} + +// PReadResp is used to return the result of pread(2). +type PReadResp struct { + NumBytes primitive.Uint32 + Buf []byte +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (r *PReadResp) SizeBytes() int { + return r.NumBytes.SizeBytes() + int(r.NumBytes) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (r *PReadResp) MarshalBytes(dst []byte) { + r.NumBytes.MarshalUnsafe(dst) + dst = dst[r.NumBytes.SizeBytes():] + copy(dst[:r.NumBytes], r.Buf[:r.NumBytes]) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (r *PReadResp) UnmarshalBytes(src []byte) { + r.NumBytes.UnmarshalUnsafe(src) + src = src[r.NumBytes.SizeBytes():] + + // We expect the client to have already allocated r.Buf. r.Buf probably + // (optimally) points to usermem. Directly copy into that. + copy(r.Buf[:r.NumBytes], src[:r.NumBytes]) +} + +// PWriteReq is used to pwrite(2) on an FD. +type PWriteReq struct { + Offset primitive.Uint64 + FD FDID + NumBytes primitive.Uint32 + Buf []byte +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (w *PWriteReq) SizeBytes() int { + return w.Offset.SizeBytes() + w.FD.SizeBytes() + w.NumBytes.SizeBytes() + int(w.NumBytes) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (w *PWriteReq) MarshalBytes(dst []byte) { + w.Offset.MarshalUnsafe(dst) + dst = dst[w.Offset.SizeBytes():] + w.FD.MarshalUnsafe(dst) + dst = dst[w.FD.SizeBytes():] + w.NumBytes.MarshalUnsafe(dst) + dst = dst[w.NumBytes.SizeBytes():] + copy(dst[:w.NumBytes], w.Buf[:w.NumBytes]) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (w *PWriteReq) UnmarshalBytes(src []byte) { + w.Offset.UnmarshalUnsafe(src) + src = src[w.Offset.SizeBytes():] + w.FD.UnmarshalUnsafe(src) + src = src[w.FD.SizeBytes():] + w.NumBytes.UnmarshalUnsafe(src) + src = src[w.NumBytes.SizeBytes():] + + // This is an optimization. Assuming that the server is making this call, it + // is safe to just point to src rather than allocating and copying. + w.Buf = src[:w.NumBytes] +} + +// PWriteResp is used to return the result of pwrite(2). +// +// +marshal +type PWriteResp struct { + Count uint64 +} + +// MkdirAtReq is used to make MkdirAt requests. +type MkdirAtReq struct { + createCommon + Name SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MkdirAtReq) SizeBytes() int { + return m.createCommon.SizeBytes() + m.Name.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MkdirAtReq) MarshalBytes(dst []byte) { + m.createCommon.MarshalUnsafe(dst) + dst = dst[m.createCommon.SizeBytes():] + m.Name.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MkdirAtReq) UnmarshalBytes(src []byte) { + m.createCommon.UnmarshalUnsafe(src) + src = src[m.createCommon.SizeBytes():] + m.Name.UnmarshalBytes(src) +} + +// MkdirAtResp is the response to a successful MkdirAt request. +// +// +marshal +type MkdirAtResp struct { + ChildDir Inode +} + +// MknodAtReq is used to make MknodAt requests. +type MknodAtReq struct { + createCommon + Name SizedString + Minor primitive.Uint32 + Major primitive.Uint32 +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MknodAtReq) SizeBytes() int { + return m.createCommon.SizeBytes() + m.Name.SizeBytes() + m.Minor.SizeBytes() + m.Major.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MknodAtReq) MarshalBytes(dst []byte) { + m.createCommon.MarshalUnsafe(dst) + dst = dst[m.createCommon.SizeBytes():] + m.Name.MarshalBytes(dst) + dst = dst[m.Name.SizeBytes():] + m.Minor.MarshalUnsafe(dst) + dst = dst[m.Minor.SizeBytes():] + m.Major.MarshalUnsafe(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MknodAtReq) UnmarshalBytes(src []byte) { + m.createCommon.UnmarshalUnsafe(src) + src = src[m.createCommon.SizeBytes():] + m.Name.UnmarshalBytes(src) + src = src[m.Name.SizeBytes():] + m.Minor.UnmarshalUnsafe(src) + src = src[m.Minor.SizeBytes():] + m.Major.UnmarshalUnsafe(src) +} + +// MknodAtResp is the response to a successful MknodAt request. +// +// +marshal +type MknodAtResp struct { + Child Inode +} + +// SymlinkAtReq is used to make SymlinkAt request. +type SymlinkAtReq struct { + DirFD FDID + Name SizedString + Target SizedString + UID UID + GID GID +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SymlinkAtReq) SizeBytes() int { + return s.DirFD.SizeBytes() + s.Name.SizeBytes() + s.Target.SizeBytes() + s.UID.SizeBytes() + s.GID.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SymlinkAtReq) MarshalBytes(dst []byte) { + s.DirFD.MarshalUnsafe(dst) + dst = dst[s.DirFD.SizeBytes():] + s.Name.MarshalBytes(dst) + dst = dst[s.Name.SizeBytes():] + s.Target.MarshalBytes(dst) + dst = dst[s.Target.SizeBytes():] + s.UID.MarshalUnsafe(dst) + dst = dst[s.UID.SizeBytes():] + s.GID.MarshalUnsafe(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SymlinkAtReq) UnmarshalBytes(src []byte) { + s.DirFD.UnmarshalUnsafe(src) + src = src[s.DirFD.SizeBytes():] + s.Name.UnmarshalBytes(src) + src = src[s.Name.SizeBytes():] + s.Target.UnmarshalBytes(src) + src = src[s.Target.SizeBytes():] + s.UID.UnmarshalUnsafe(src) + src = src[s.UID.SizeBytes():] + s.GID.UnmarshalUnsafe(src) +} + +// SymlinkAtResp is the response to a successful SymlinkAt request. +// +// +marshal +type SymlinkAtResp struct { + Symlink Inode +} + +// LinkAtReq is used to make LinkAt requests. +type LinkAtReq struct { + DirFD FDID + Target FDID + Name SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (l *LinkAtReq) SizeBytes() int { + return l.DirFD.SizeBytes() + l.Target.SizeBytes() + l.Name.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (l *LinkAtReq) MarshalBytes(dst []byte) { + l.DirFD.MarshalUnsafe(dst) + dst = dst[l.DirFD.SizeBytes():] + l.Target.MarshalUnsafe(dst) + dst = dst[l.Target.SizeBytes():] + l.Name.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (l *LinkAtReq) UnmarshalBytes(src []byte) { + l.DirFD.UnmarshalUnsafe(src) + src = src[l.DirFD.SizeBytes():] + l.Target.UnmarshalUnsafe(src) + src = src[l.Target.SizeBytes():] + l.Name.UnmarshalBytes(src) +} + +// LinkAtResp is used to respond to a successful LinkAt request. +// +// +marshal +type LinkAtResp struct { + Link Inode +} + +// FStatFSReq is used to request StatFS results for the specified FD. +// +// +marshal +type FStatFSReq struct { + FD FDID +} + +// StatFS is responded to a successful FStatFS request. +// +// +marshal +type StatFS struct { + Type uint64 + BlockSize int64 + Blocks uint64 + BlocksFree uint64 + BlocksAvailable uint64 + Files uint64 + FilesFree uint64 + NameLength uint64 +} + +// FAllocateReq is used to request to fallocate(2) an FD. This has no response. +// +// +marshal +type FAllocateReq struct { + FD FDID + _ uint32 + Mode uint64 + Offset uint64 + Length uint64 +} + +// ReadLinkAtReq is used to readlinkat(2) at the specified FD. +// +// +marshal +type ReadLinkAtReq struct { + FD FDID +} + +// ReadLinkAtResp is used to communicate ReadLinkAt results. +type ReadLinkAtResp struct { + Target SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (r *ReadLinkAtResp) SizeBytes() int { + return r.Target.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (r *ReadLinkAtResp) MarshalBytes(dst []byte) { + r.Target.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (r *ReadLinkAtResp) UnmarshalBytes(src []byte) { + r.Target.UnmarshalBytes(src) +} + +// FlushReq is used to make Flush requests. +// +// +marshal +type FlushReq struct { + FD FDID +} + +// ConnectReq is used to make a Connect request. +// +// +marshal +type ConnectReq struct { + FD FDID + // SockType is used to specify the socket type to connect to. As a special + // case, SockType = 0 means that the socket type does not matter and the + // requester will accept any socket type. + SockType uint32 +} + +// UnlinkAtReq is used to make UnlinkAt request. +type UnlinkAtReq struct { + DirFD FDID + Name SizedString + Flags primitive.Uint32 +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (u *UnlinkAtReq) SizeBytes() int { + return u.DirFD.SizeBytes() + u.Name.SizeBytes() + u.Flags.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (u *UnlinkAtReq) MarshalBytes(dst []byte) { + u.DirFD.MarshalUnsafe(dst) + dst = dst[u.DirFD.SizeBytes():] + u.Name.MarshalBytes(dst) + dst = dst[u.Name.SizeBytes():] + u.Flags.MarshalUnsafe(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (u *UnlinkAtReq) UnmarshalBytes(src []byte) { + u.DirFD.UnmarshalUnsafe(src) + src = src[u.DirFD.SizeBytes():] + u.Name.UnmarshalBytes(src) + src = src[u.Name.SizeBytes():] + u.Flags.UnmarshalUnsafe(src) +} + +// RenameAtReq is used to make Rename requests. Note that the request takes in +// the to-be-renamed file's FD instead of oldDir and oldName like renameat(2). +type RenameAtReq struct { + Renamed FDID + NewDir FDID + NewName SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (r *RenameAtReq) SizeBytes() int { + return r.Renamed.SizeBytes() + r.NewDir.SizeBytes() + r.NewName.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (r *RenameAtReq) MarshalBytes(dst []byte) { + r.Renamed.MarshalUnsafe(dst) + dst = dst[r.Renamed.SizeBytes():] + r.NewDir.MarshalUnsafe(dst) + dst = dst[r.NewDir.SizeBytes():] + r.NewName.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (r *RenameAtReq) UnmarshalBytes(src []byte) { + r.Renamed.UnmarshalUnsafe(src) + src = src[r.Renamed.SizeBytes():] + r.NewDir.UnmarshalUnsafe(src) + src = src[r.NewDir.SizeBytes():] + r.NewName.UnmarshalBytes(src) +} + +// Getdents64Req is used to make Getdents64 requests. +// +// +marshal +type Getdents64Req struct { + DirFD FDID + // Count is the number of bytes to read. A negative value of Count is used to + // indicate that the implementation must lseek(0, SEEK_SET) before calling + // getdents64(2). Implementations must use the absolute value of Count to + // determine the number of bytes to read. + Count int32 +} + +// Dirent64 is analogous to struct linux_dirent64. +type Dirent64 struct { + Ino primitive.Uint64 + DevMinor primitive.Uint32 + DevMajor primitive.Uint32 + Off primitive.Uint64 + Type primitive.Uint8 + Name SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (d *Dirent64) SizeBytes() int { + return d.Ino.SizeBytes() + d.DevMinor.SizeBytes() + d.DevMajor.SizeBytes() + d.Off.SizeBytes() + d.Type.SizeBytes() + d.Name.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (d *Dirent64) MarshalBytes(dst []byte) { + d.Ino.MarshalUnsafe(dst) + dst = dst[d.Ino.SizeBytes():] + d.DevMinor.MarshalUnsafe(dst) + dst = dst[d.DevMinor.SizeBytes():] + d.DevMajor.MarshalUnsafe(dst) + dst = dst[d.DevMajor.SizeBytes():] + d.Off.MarshalUnsafe(dst) + dst = dst[d.Off.SizeBytes():] + d.Type.MarshalUnsafe(dst) + dst = dst[d.Type.SizeBytes():] + d.Name.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (d *Dirent64) UnmarshalBytes(src []byte) { + d.Ino.UnmarshalUnsafe(src) + src = src[d.Ino.SizeBytes():] + d.DevMinor.UnmarshalUnsafe(src) + src = src[d.DevMinor.SizeBytes():] + d.DevMajor.UnmarshalUnsafe(src) + src = src[d.DevMajor.SizeBytes():] + d.Off.UnmarshalUnsafe(src) + src = src[d.Off.SizeBytes():] + d.Type.UnmarshalUnsafe(src) + src = src[d.Type.SizeBytes():] + d.Name.UnmarshalBytes(src) +} + +// Getdents64Resp is used to communicate getdents64 results. +type Getdents64Resp struct { + Dirents []Dirent64 +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (g *Getdents64Resp) SizeBytes() int { + ret := (*primitive.Uint32)(nil).SizeBytes() + for i := range g.Dirents { + ret += g.Dirents[i].SizeBytes() + } + return ret +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (g *Getdents64Resp) MarshalBytes(dst []byte) { + numDirents := primitive.Uint32(len(g.Dirents)) + numDirents.MarshalUnsafe(dst) + dst = dst[numDirents.SizeBytes():] + for i := range g.Dirents { + g.Dirents[i].MarshalBytes(dst) + dst = dst[g.Dirents[i].SizeBytes():] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (g *Getdents64Resp) UnmarshalBytes(src []byte) { + var numDirents primitive.Uint32 + numDirents.UnmarshalUnsafe(src) + if cap(g.Dirents) < int(numDirents) { + g.Dirents = make([]Dirent64, numDirents) + } else { + g.Dirents = g.Dirents[:numDirents] + } + + src = src[numDirents.SizeBytes():] + for i := range g.Dirents { + g.Dirents[i].UnmarshalBytes(src) + src = src[g.Dirents[i].SizeBytes():] + } +} + +// FGetXattrReq is used to make FGetXattr requests. The response to this is +// just a SizedString containing the xattr value. +type FGetXattrReq struct { + FD FDID + BufSize primitive.Uint32 + Name SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (g *FGetXattrReq) SizeBytes() int { + return g.FD.SizeBytes() + g.BufSize.SizeBytes() + g.Name.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (g *FGetXattrReq) MarshalBytes(dst []byte) { + g.FD.MarshalUnsafe(dst) + dst = dst[g.FD.SizeBytes():] + g.BufSize.MarshalUnsafe(dst) + dst = dst[g.BufSize.SizeBytes():] + g.Name.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (g *FGetXattrReq) UnmarshalBytes(src []byte) { + g.FD.UnmarshalUnsafe(src) + src = src[g.FD.SizeBytes():] + g.BufSize.UnmarshalUnsafe(src) + src = src[g.BufSize.SizeBytes():] + g.Name.UnmarshalBytes(src) +} + +// FGetXattrResp is used to respond to FGetXattr request. +type FGetXattrResp struct { + Value SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (g *FGetXattrResp) SizeBytes() int { + return g.Value.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (g *FGetXattrResp) MarshalBytes(dst []byte) { + g.Value.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (g *FGetXattrResp) UnmarshalBytes(src []byte) { + g.Value.UnmarshalBytes(src) +} + +// FSetXattrReq is used to make FSetXattr requests. It has no response. +type FSetXattrReq struct { + FD FDID + Flags primitive.Uint32 + Name SizedString + Value SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *FSetXattrReq) SizeBytes() int { + return s.FD.SizeBytes() + s.Flags.SizeBytes() + s.Name.SizeBytes() + s.Value.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *FSetXattrReq) MarshalBytes(dst []byte) { + s.FD.MarshalUnsafe(dst) + dst = dst[s.FD.SizeBytes():] + s.Flags.MarshalUnsafe(dst) + dst = dst[s.Flags.SizeBytes():] + s.Name.MarshalBytes(dst) + dst = dst[s.Name.SizeBytes():] + s.Value.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *FSetXattrReq) UnmarshalBytes(src []byte) { + s.FD.UnmarshalUnsafe(src) + src = src[s.FD.SizeBytes():] + s.Flags.UnmarshalUnsafe(src) + src = src[s.Flags.SizeBytes():] + s.Name.UnmarshalBytes(src) + src = src[s.Name.SizeBytes():] + s.Value.UnmarshalBytes(src) +} + +// FRemoveXattrReq is used to make FRemoveXattr requests. It has no response. +type FRemoveXattrReq struct { + FD FDID + Name SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (r *FRemoveXattrReq) SizeBytes() int { + return r.FD.SizeBytes() + r.Name.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (r *FRemoveXattrReq) MarshalBytes(dst []byte) { + r.FD.MarshalUnsafe(dst) + dst = dst[r.FD.SizeBytes():] + r.Name.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (r *FRemoveXattrReq) UnmarshalBytes(src []byte) { + r.FD.UnmarshalUnsafe(src) + src = src[r.FD.SizeBytes():] + r.Name.UnmarshalBytes(src) +} + +// FListXattrReq is used to make FListXattr requests. +// +// +marshal +type FListXattrReq struct { + FD FDID + _ uint32 + Size uint64 +} + +// FListXattrResp is used to respond to FListXattr requests. +type FListXattrResp struct { + Xattrs StringArray +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (l *FListXattrResp) SizeBytes() int { + return l.Xattrs.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (l *FListXattrResp) MarshalBytes(dst []byte) { + l.Xattrs.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (l *FListXattrResp) UnmarshalBytes(src []byte) { + l.Xattrs.UnmarshalBytes(src) +} diff --git a/pkg/lisafs/sample_message.go b/pkg/lisafs/sample_message.go new file mode 100644 index 000000000..3868dfa08 --- /dev/null +++ b/pkg/lisafs/sample_message.go @@ -0,0 +1,110 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "math/rand" + + "gvisor.dev/gvisor/pkg/marshal/primitive" +) + +// MsgSimple is a sample packed struct which can be used to test message passing. +// +// +marshal slice:Msg1Slice +type MsgSimple struct { + A uint16 + B uint16 + C uint32 + D uint64 +} + +// Randomize randomizes the contents of m. +func (m *MsgSimple) Randomize() { + m.A = uint16(rand.Uint32()) + m.B = uint16(rand.Uint32()) + m.C = rand.Uint32() + m.D = rand.Uint64() +} + +// MsgDynamic is a sample dynamic struct which can be used to test message passing. +// +// +marshal dynamic +type MsgDynamic struct { + N primitive.Uint32 + Arr []MsgSimple +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MsgDynamic) SizeBytes() int { + return m.N.SizeBytes() + + (int(m.N) * (*MsgSimple)(nil).SizeBytes()) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MsgDynamic) MarshalBytes(dst []byte) { + m.N.MarshalUnsafe(dst) + dst = dst[m.N.SizeBytes():] + MarshalUnsafeMsg1Slice(m.Arr, dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MsgDynamic) UnmarshalBytes(src []byte) { + m.N.UnmarshalUnsafe(src) + src = src[m.N.SizeBytes():] + m.Arr = make([]MsgSimple, m.N) + UnmarshalUnsafeMsg1Slice(m.Arr, src) +} + +// Randomize randomizes the contents of m. +func (m *MsgDynamic) Randomize(arrLen int) { + m.N = primitive.Uint32(arrLen) + m.Arr = make([]MsgSimple, arrLen) + for i := 0; i < arrLen; i++ { + m.Arr[i].Randomize() + } +} + +// P9Version mimics p9.TVersion and p9.Rversion. +// +// +marshal dynamic +type P9Version struct { + MSize primitive.Uint32 + Version string +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (v *P9Version) SizeBytes() int { + return (*primitive.Uint32)(nil).SizeBytes() + (*primitive.Uint16)(nil).SizeBytes() + len(v.Version) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (v *P9Version) MarshalBytes(dst []byte) { + v.MSize.MarshalUnsafe(dst) + dst = dst[v.MSize.SizeBytes():] + versionLen := primitive.Uint16(len(v.Version)) + versionLen.MarshalUnsafe(dst) + dst = dst[versionLen.SizeBytes():] + copy(dst, v.Version) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (v *P9Version) UnmarshalBytes(src []byte) { + v.MSize.UnmarshalUnsafe(src) + src = src[v.MSize.SizeBytes():] + var versionLen primitive.Uint16 + versionLen.UnmarshalUnsafe(src) + src = src[versionLen.SizeBytes():] + v.Version = string(src[:versionLen]) +} diff --git a/pkg/lisafs/server.go b/pkg/lisafs/server.go new file mode 100644 index 000000000..7515355ec --- /dev/null +++ b/pkg/lisafs/server.go @@ -0,0 +1,113 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "gvisor.dev/gvisor/pkg/sync" +) + +// Server serves a filesystem tree. Multiple connections on different mount +// points can be started on a server. The server provides utilities to safely +// modify the filesystem tree across its connections (mount points). Note that +// it does not support synchronizing filesystem tree mutations across other +// servers serving the same filesystem subtree. Server also manages the +// lifecycle of all connections. +type Server struct { + // connWg counts the number of active connections being tracked. + connWg sync.WaitGroup + + // RenameMu synchronizes rename operations within this filesystem tree. + RenameMu sync.RWMutex + + // handlers is a list of RPC handlers which can be indexed by the handler's + // corresponding MID. + handlers []RPCHandler + + // mountPoints keeps track of all the mount points this server serves. + mpMu sync.RWMutex + mountPoints []*ControlFD + + // impl is the server implementation which embeds this server. + impl ServerImpl +} + +// Init must be called before first use of server. +func (s *Server) Init(impl ServerImpl) { + s.impl = impl + s.handlers = handlers[:] +} + +// InitTestOnly is the same as Init except that it allows to swap out the +// underlying handlers with something custom. This is for test only. +func (s *Server) InitTestOnly(impl ServerImpl, handlers []RPCHandler) { + s.impl = impl + s.handlers = handlers +} + +// WithRenameReadLock invokes fn with the server's rename mutex locked for +// reading. This ensures that no rename operations occur concurrently. +func (s *Server) WithRenameReadLock(fn func() error) error { + s.RenameMu.RLock() + err := fn() + s.RenameMu.RUnlock() + return err +} + +// StartConnection starts the connection on a separate goroutine and tracks it. +func (s *Server) StartConnection(c *Connection) { + s.connWg.Add(1) + go func() { + c.Run() + s.connWg.Done() + }() +} + +// Wait waits for all connections started via StartConnection() to terminate. +func (s *Server) Wait() { + s.connWg.Wait() +} + +func (s *Server) addMountPoint(root *ControlFD) { + s.mpMu.Lock() + defer s.mpMu.Unlock() + s.mountPoints = append(s.mountPoints, root) +} + +func (s *Server) forEachMountPoint(fn func(root *ControlFD)) { + s.mpMu.RLock() + defer s.mpMu.RUnlock() + for _, mp := range s.mountPoints { + fn(mp) + } +} + +// ServerImpl contains the implementation details for a Server. +// Implementations of ServerImpl should contain their associated Server by +// value as their first field. +type ServerImpl interface { + // Mount is called when a Mount RPC is made. It mounts the connection at + // mountPath. + // + // Precondition: mountPath == path.Clean(mountPath). + Mount(c *Connection, mountPath string) (ControlFDImpl, Inode, error) + + // SupportedMessages returns a list of messages that the server + // implementation supports. + SupportedMessages() []MID + + // MaxMessageSize is the maximum payload length (in bytes) that can be sent + // to this server implementation. + MaxMessageSize() uint32 +} diff --git a/pkg/lisafs/sock.go b/pkg/lisafs/sock.go new file mode 100644 index 000000000..88210242f --- /dev/null +++ b/pkg/lisafs/sock.go @@ -0,0 +1,208 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "io" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/unet" +) + +var ( + sockHeaderLen = uint32((*sockHeader)(nil).SizeBytes()) +) + +// sockHeader is the header present in front of each message received on a UDS. +// +// +marshal +type sockHeader struct { + payloadLen uint32 + message MID + _ uint16 // Need to make struct packed. +} + +// sockCommunicator implements Communicator. This is not thread safe. +type sockCommunicator struct { + fdTracker + sock *unet.Socket + buf []byte +} + +var _ Communicator = (*sockCommunicator)(nil) + +func newSockComm(sock *unet.Socket) *sockCommunicator { + return &sockCommunicator{ + sock: sock, + buf: make([]byte, sockHeaderLen), + } +} + +func (s *sockCommunicator) FD() int { + return s.sock.FD() +} + +func (s *sockCommunicator) destroy() { + s.sock.Close() +} + +func (s *sockCommunicator) shutdown() { + if err := s.sock.Shutdown(); err != nil { + log.Warningf("Socket.Shutdown() failed (FD: %d): %v", s.sock.FD(), err) + } +} + +func (s *sockCommunicator) resizeBuf(size uint32) { + if cap(s.buf) < int(size) { + s.buf = s.buf[:cap(s.buf)] + s.buf = append(s.buf, make([]byte, int(size)-cap(s.buf))...) + } else { + s.buf = s.buf[:size] + } +} + +// PayloadBuf implements Communicator.PayloadBuf. +func (s *sockCommunicator) PayloadBuf(size uint32) []byte { + s.resizeBuf(sockHeaderLen + size) + return s.buf[sockHeaderLen : sockHeaderLen+size] +} + +// SndRcvMessage implements Communicator.SndRcvMessage. +func (s *sockCommunicator) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) { + if err := s.sndPrepopulatedMsg(m, payloadLen, nil); err != nil { + return 0, 0, err + } + + return s.rcvMsg(wantFDs) +} + +// sndPrepopulatedMsg assumes that s.buf has already been populated with +// `payloadLen` bytes of data. +func (s *sockCommunicator) sndPrepopulatedMsg(m MID, payloadLen uint32, fds []int) error { + header := sockHeader{payloadLen: payloadLen, message: m} + header.MarshalUnsafe(s.buf) + dataLen := sockHeaderLen + payloadLen + return writeTo(s.sock, [][]byte{s.buf[:dataLen]}, int(dataLen), fds) +} + +// writeTo writes the passed iovec to the UDS and donates any passed FDs. +func writeTo(sock *unet.Socket, iovec [][]byte, dataLen int, fds []int) error { + w := sock.Writer(true) + if len(fds) > 0 { + w.PackFDs(fds...) + } + + fdsUnpacked := false + for n := 0; n < dataLen; { + cur, err := w.WriteVec(iovec) + if err != nil { + return err + } + n += cur + + // Fast common path. + if n >= dataLen { + break + } + + // Consume iovecs. + for consumed := 0; consumed < cur; { + if len(iovec[0]) <= cur-consumed { + consumed += len(iovec[0]) + iovec = iovec[1:] + } else { + iovec[0] = iovec[0][cur-consumed:] + break + } + } + + if n > 0 && !fdsUnpacked { + // Don't resend any control message. + fdsUnpacked = true + w.UnpackFDs() + } + } + return nil +} + +// rcvMsg reads the message header and payload from the UDS. It also populates +// fds with any donated FDs. +func (s *sockCommunicator) rcvMsg(wantFDs uint8) (MID, uint32, error) { + fds, err := readFrom(s.sock, s.buf[:sockHeaderLen], wantFDs) + if err != nil { + return 0, 0, err + } + for _, fd := range fds { + s.TrackFD(fd) + } + + var header sockHeader + header.UnmarshalUnsafe(s.buf) + + // No payload? We are done. + if header.payloadLen == 0 { + return header.message, 0, nil + } + + if _, err := readFrom(s.sock, s.PayloadBuf(header.payloadLen), 0); err != nil { + return 0, 0, err + } + + return header.message, header.payloadLen, nil +} + +// readFrom fills the passed buffer with data from the socket. It also returns +// any donated FDs. +func readFrom(sock *unet.Socket, buf []byte, wantFDs uint8) ([]int, error) { + r := sock.Reader(true) + r.EnableFDs(int(wantFDs)) + + var ( + fds []int + fdInit bool + ) + n := len(buf) + for got := 0; got < n; { + cur, err := r.ReadVec([][]byte{buf[got:]}) + + // Ignore EOF if cur > 0. + if err != nil && (err != io.EOF || cur == 0) { + r.CloseFDs() + return nil, err + } + + if !fdInit && cur > 0 { + fds, err = r.ExtractFDs() + if err != nil { + return nil, err + } + + fdInit = true + r.EnableFDs(0) + } + + got += cur + } + return fds, nil +} + +func closeFDs(fds []int) { + for _, fd := range fds { + if fd >= 0 { + unix.Close(fd) + } + } +} diff --git a/pkg/lisafs/sock_test.go b/pkg/lisafs/sock_test.go new file mode 100644 index 000000000..387f4b7a8 --- /dev/null +++ b/pkg/lisafs/sock_test.go @@ -0,0 +1,217 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "bytes" + "math/rand" + "reflect" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +func runSocketTest(t *testing.T, fun1 func(*sockCommunicator), fun2 func(*sockCommunicator)) { + sock1, sock2, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + defer sock1.Close() + defer sock2.Close() + + var testWg sync.WaitGroup + testWg.Add(2) + + go func() { + fun1(newSockComm(sock1)) + testWg.Done() + }() + + go func() { + fun2(newSockComm(sock2)) + testWg.Done() + }() + + testWg.Wait() +} + +func TestReadWrite(t *testing.T) { + // Create random data to send. + n := 10000 + data := make([]byte, n) + if _, err := rand.Read(data); err != nil { + t.Fatalf("rand.Read(data) failed: %v", err) + } + + runSocketTest(t, func(comm *sockCommunicator) { + // Scatter that data into two parts using Iovecs while sending. + mid := n / 2 + if err := writeTo(comm.sock, [][]byte{data[:mid], data[mid:]}, n, nil); err != nil { + t.Errorf("writeTo socket failed: %v", err) + } + }, func(comm *sockCommunicator) { + gotData := make([]byte, n) + if _, err := readFrom(comm.sock, gotData, 0); err != nil { + t.Fatalf("reading from socket failed: %v", err) + } + + // Make sure we got the right data. + if res := bytes.Compare(data, gotData); res != 0 { + t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData) + } + }) +} + +func TestFDDonation(t *testing.T) { + n := 10 + data := make([]byte, n) + if _, err := rand.Read(data); err != nil { + t.Fatalf("rand.Read(data) failed: %v", err) + } + + // Try donating FDs to these files. + path1 := "/dev/null" + path2 := "/dev" + path3 := "/dev/random" + + runSocketTest(t, func(comm *sockCommunicator) { + devNullFD, err := unix.Open(path1, unix.O_RDONLY, 0) + defer unix.Close(devNullFD) + if err != nil { + t.Fatalf("open(%s) failed: %v", path1, err) + } + devFD, err := unix.Open(path2, unix.O_RDONLY, 0) + defer unix.Close(devFD) + if err != nil { + t.Fatalf("open(%s) failed: %v", path2, err) + } + devRandomFD, err := unix.Open(path3, unix.O_RDONLY, 0) + defer unix.Close(devRandomFD) + if err != nil { + t.Fatalf("open(%s) failed: %v", path2, err) + } + if err := writeTo(comm.sock, [][]byte{data}, n, []int{devNullFD, devFD, devRandomFD}); err != nil { + t.Errorf("writeTo socket failed: %v", err) + } + }, func(comm *sockCommunicator) { + gotData := make([]byte, n) + fds, err := readFrom(comm.sock, gotData, 3) + if err != nil { + t.Fatalf("reading from socket failed: %v", err) + } + defer closeFDs(fds[:]) + + if res := bytes.Compare(data, gotData); res != 0 { + t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData) + } + + if len(fds) != 3 { + t.Fatalf("wanted 3 FD, got %d", len(fds)) + } + + // Check that the FDs actually point to the correct file. + compareFDWithFile(t, fds[0], path1) + compareFDWithFile(t, fds[1], path2) + compareFDWithFile(t, fds[2], path3) + }) +} + +func compareFDWithFile(t *testing.T, fd int, path string) { + var want unix.Stat_t + if err := unix.Stat(path, &want); err != nil { + t.Fatalf("stat(%s) failed: %v", path, err) + } + + var got unix.Stat_t + if err := unix.Fstat(fd, &got); err != nil { + t.Fatalf("fstat on donated FD failed: %v", err) + } + + if got.Ino != want.Ino || got.Dev != want.Dev { + t.Errorf("FD does not point to %s, want = %+v, got = %+v", path, want, got) + } +} + +func testSndMsg(comm *sockCommunicator, m MID, msg marshal.Marshallable) error { + var payloadLen uint32 + if msg != nil { + payloadLen = uint32(msg.SizeBytes()) + msg.MarshalUnsafe(comm.PayloadBuf(payloadLen)) + } + return comm.sndPrepopulatedMsg(m, payloadLen, nil) +} + +func TestSndRcvMessage(t *testing.T) { + req := &MsgSimple{} + req.Randomize() + reqM := MID(1) + + // Create a massive random response. + var resp MsgDynamic + resp.Randomize(100) + respM := MID(2) + + runSocketTest(t, func(comm *sockCommunicator) { + if err := testSndMsg(comm, reqM, req); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + checkMessageReceive(t, comm, respM, &resp) + }, func(comm *sockCommunicator) { + checkMessageReceive(t, comm, reqM, req) + if err := testSndMsg(comm, respM, &resp); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + }) +} + +func TestSndRcvMessageNoPayload(t *testing.T) { + reqM := MID(1) + respM := MID(2) + runSocketTest(t, func(comm *sockCommunicator) { + if err := testSndMsg(comm, reqM, nil); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + checkMessageReceive(t, comm, respM, nil) + }, func(comm *sockCommunicator) { + checkMessageReceive(t, comm, reqM, nil) + if err := testSndMsg(comm, respM, nil); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + }) +} + +func checkMessageReceive(t *testing.T, comm *sockCommunicator, wantM MID, wantMsg marshal.Marshallable) { + gotM, payloadLen, err := comm.rcvMsg(0) + if err != nil { + t.Fatalf("readMessageFrom failed: %v", err) + } + if gotM != wantM { + t.Errorf("got incorrect message ID: got = %d, want = %d", gotM, wantM) + } + if wantMsg == nil { + if payloadLen != 0 { + t.Errorf("no payload expect but got %d bytes", payloadLen) + } + } else { + gotMsg := reflect.New(reflect.ValueOf(wantMsg).Elem().Type()).Interface().(marshal.Marshallable) + gotMsg.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + if !reflect.DeepEqual(wantMsg, gotMsg) { + t.Errorf("msg differs: want = %+v, got = %+v", wantMsg, gotMsg) + } + } +} diff --git a/pkg/lisafs/testsuite/BUILD b/pkg/lisafs/testsuite/BUILD new file mode 100644 index 000000000..b4a542b3a --- /dev/null +++ b/pkg/lisafs/testsuite/BUILD @@ -0,0 +1,20 @@ +load("//tools:defs.bzl", "go_library") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +go_library( + name = "testsuite", + testonly = True, + srcs = ["testsuite.go"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/lisafs", + "//pkg/unet", + "@com_github_syndtr_gocapability//capability:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/lisafs/testsuite/testsuite.go b/pkg/lisafs/testsuite/testsuite.go new file mode 100644 index 000000000..5fc7c364d --- /dev/null +++ b/pkg/lisafs/testsuite/testsuite.go @@ -0,0 +1,637 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testsuite provides a integration testing suite for lisafs. +// These tests are intended for servers serving the local filesystem. +package testsuite + +import ( + "bytes" + "fmt" + "io/ioutil" + "math/rand" + "os" + "testing" + "time" + + "github.com/syndtr/gocapability/capability" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/lisafs" + "gvisor.dev/gvisor/pkg/unet" +) + +// Tester is the client code using this test suite. This interface abstracts +// away all the caller specific details. +type Tester interface { + // NewServer returns a new instance of the tester server. + NewServer(t *testing.T) *lisafs.Server + + // LinkSupported returns true if the backing server supports LinkAt. + LinkSupported() bool + + // SetUserGroupIDSupported returns true if the backing server supports + // changing UID/GID for files. + SetUserGroupIDSupported() bool +} + +// RunAllLocalFSTests runs all local FS tests as subtests. +func RunAllLocalFSTests(t *testing.T, tester Tester) { + for name, testFn := range localFSTests { + t.Run(name, func(t *testing.T) { + runServerClient(t, tester, testFn) + }) + } +} + +type testFunc func(context.Context, *testing.T, Tester, lisafs.ClientFD) + +var localFSTests map[string]testFunc = map[string]testFunc{ + "Stat": testStat, + "RegularFileIO": testRegularFileIO, + "RegularFileOpen": testRegularFileOpen, + "SetStat": testSetStat, + "Allocate": testAllocate, + "StatFS": testStatFS, + "Unlink": testUnlink, + "Symlink": testSymlink, + "HardLink": testHardLink, + "Walk": testWalk, + "Rename": testRename, + "Mknod": testMknod, + "Getdents": testGetdents, +} + +func runServerClient(t *testing.T, tester Tester, testFn testFunc) { + mountPath, err := ioutil.TempDir(os.Getenv("TEST_TMPDIR"), "") + if err != nil { + t.Fatalf("creation of temporary mountpoint failed: %v", err) + } + defer os.RemoveAll(mountPath) + + // fsgofer should run with a umask of 0, because we want to preserve file + // modes exactly for testing purposes. + unix.Umask(0) + + serverSocket, clientSocket, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + + server := tester.NewServer(t) + conn, err := server.CreateConnection(serverSocket, false /* readonly */) + if err != nil { + t.Fatalf("starting connection failed: %v", err) + return + } + server.StartConnection(conn) + + c, root, err := lisafs.NewClient(clientSocket, mountPath) + if err != nil { + t.Fatalf("client creation failed: %v", err) + } + + if !root.ControlFD.Ok() { + t.Fatalf("root control FD is not valid") + } + rootFile := c.NewFD(root.ControlFD) + + ctx := context.Background() + testFn(ctx, t, tester, rootFile) + closeFD(ctx, t, rootFile) + + c.Close() // This should trigger client and server shutdown. + server.Wait() +} + +func closeFD(ctx context.Context, t testing.TB, fdLisa lisafs.ClientFD) { + if err := fdLisa.Close(ctx); err != nil { + t.Errorf("failed to close FD: %v", err) + } +} + +func statTo(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, stat *linux.Statx) { + if err := fdLisa.StatTo(ctx, stat); err != nil { + t.Fatalf("stat failed: %v", err) + } +} + +func openCreateFile(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, name string) (lisafs.ClientFD, linux.Statx, lisafs.ClientFD, int) { + child, childFD, childHostFD, err := fdLisa.OpenCreateAt(ctx, name, unix.O_RDWR, 0777, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid())) + if err != nil { + t.Fatalf("OpenCreateAt failed: %v", err) + } + if childHostFD == -1 { + t.Error("no host FD donated") + } + client := fdLisa.Client() + return client.NewFD(child.ControlFD), child.Stat, fdLisa.Client().NewFD(childFD), childHostFD +} + +func openFile(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, flags uint32, isReg bool) (lisafs.ClientFD, int) { + newFD, hostFD, err := fdLisa.OpenAt(ctx, flags) + if err != nil { + t.Fatalf("OpenAt failed: %v", err) + } + if hostFD == -1 && isReg { + t.Error("no host FD donated") + } + return fdLisa.Client().NewFD(newFD), hostFD +} + +func unlinkFile(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string, isDir bool) { + var flags uint32 + if isDir { + flags = unix.AT_REMOVEDIR + } + if err := dir.UnlinkAt(ctx, name, flags); err != nil { + t.Errorf("unlinking file %s failed: %v", name, err) + } +} + +func symlink(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name, target string) (lisafs.ClientFD, linux.Statx) { + linkIno, err := dir.SymlinkAt(ctx, name, target, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid())) + if err != nil { + t.Fatalf("symlink failed: %v", err) + } + return dir.Client().NewFD(linkIno.ControlFD), linkIno.Stat +} + +func link(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string, target lisafs.ClientFD) (lisafs.ClientFD, linux.Statx) { + linkIno, err := dir.LinkAt(ctx, target.ID(), name) + if err != nil { + t.Fatalf("link failed: %v", err) + } + return dir.Client().NewFD(linkIno.ControlFD), linkIno.Stat +} + +func mkdir(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string) (lisafs.ClientFD, linux.Statx) { + childIno, err := dir.MkdirAt(ctx, name, 0777, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid())) + if err != nil { + t.Fatalf("mkdir failed: %v", err) + } + return dir.Client().NewFD(childIno.ControlFD), childIno.Stat +} + +func mknod(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string) (lisafs.ClientFD, linux.Statx) { + nodeIno, err := dir.MknodAt(ctx, name, unix.S_IFREG|0777, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid()), 0, 0) + if err != nil { + t.Fatalf("mknod failed: %v", err) + } + return dir.Client().NewFD(nodeIno.ControlFD), nodeIno.Stat +} + +func walk(ctx context.Context, t *testing.T, dir lisafs.ClientFD, names []string) []lisafs.Inode { + _, inodes, err := dir.WalkMultiple(ctx, names) + if err != nil { + t.Fatalf("walk failed while trying to walk components %+v: %v", names, err) + } + return inodes +} + +func walkStat(ctx context.Context, t *testing.T, dir lisafs.ClientFD, names []string) []linux.Statx { + stats, err := dir.WalkStat(ctx, names) + if err != nil { + t.Fatalf("walk failed while trying to walk components %+v: %v", names, err) + } + return stats +} + +func writeFD(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, off uint64, buf []byte) error { + count, err := fdLisa.Write(ctx, buf, off) + if err != nil { + return err + } + if int(count) != len(buf) { + t.Errorf("partial write: buf size = %d, written = %d", len(buf), count) + } + return nil +} + +func readFDAndCmp(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, off uint64, want []byte) { + buf := make([]byte, len(want)) + n, err := fdLisa.Read(ctx, buf, off) + if err != nil { + t.Errorf("read failed: %v", err) + return + } + if int(n) != len(want) { + t.Errorf("partial read: buf size = %d, read = %d", len(want), n) + return + } + if bytes.Compare(buf, want) != 0 { + t.Errorf("bytes read differ from what was expected: want = %v, got = %v", want, buf) + } +} + +func allocateAndVerify(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, off uint64, length uint64) { + if err := fdLisa.Allocate(ctx, 0, off, length); err != nil { + t.Fatalf("fallocate failed: %v", err) + } + + var stat linux.Statx + statTo(ctx, t, fdLisa, &stat) + if want := off + length; stat.Size != want { + t.Errorf("incorrect file size after allocate: expected %d, got %d", off+length, stat.Size) + } +} + +func cmpStatx(t *testing.T, want, got linux.Statx) { + if got.Mask&unix.STATX_MODE != 0 && want.Mask&unix.STATX_MODE != 0 { + if got.Mode != want.Mode { + t.Errorf("mode differs: want %d, got %d", want.Mode, got.Mode) + } + } + if got.Mask&unix.STATX_INO != 0 && want.Mask&unix.STATX_INO != 0 { + if got.Ino != want.Ino { + t.Errorf("inode number differs: want %d, got %d", want.Ino, got.Ino) + } + } + if got.Mask&unix.STATX_NLINK != 0 && want.Mask&unix.STATX_NLINK != 0 { + if got.Nlink != want.Nlink { + t.Errorf("nlink differs: want %d, got %d", want.Nlink, got.Nlink) + } + } + if got.Mask&unix.STATX_UID != 0 && want.Mask&unix.STATX_UID != 0 { + if got.UID != want.UID { + t.Errorf("UID differs: want %d, got %d", want.UID, got.UID) + } + } + if got.Mask&unix.STATX_GID != 0 && want.Mask&unix.STATX_GID != 0 { + if got.GID != want.GID { + t.Errorf("GID differs: want %d, got %d", want.GID, got.GID) + } + } + if got.Mask&unix.STATX_SIZE != 0 && want.Mask&unix.STATX_SIZE != 0 { + if got.Size != want.Size { + t.Errorf("size differs: want %d, got %d", want.Size, got.Size) + } + } + if got.Mask&unix.STATX_BLOCKS != 0 && want.Mask&unix.STATX_BLOCKS != 0 { + if got.Blocks != want.Blocks { + t.Errorf("blocks differs: want %d, got %d", want.Blocks, got.Blocks) + } + } + if got.Mask&unix.STATX_ATIME != 0 && want.Mask&unix.STATX_ATIME != 0 { + if got.Atime != want.Atime { + t.Errorf("atime differs: want %d, got %d", want.Atime, got.Atime) + } + } + if got.Mask&unix.STATX_MTIME != 0 && want.Mask&unix.STATX_MTIME != 0 { + if got.Mtime != want.Mtime { + t.Errorf("mtime differs: want %d, got %d", want.Mtime, got.Mtime) + } + } + if got.Mask&unix.STATX_CTIME != 0 && want.Mask&unix.STATX_CTIME != 0 { + if got.Ctime != want.Ctime { + t.Errorf("ctime differs: want %d, got %d", want.Ctime, got.Ctime) + } + } +} + +func hasCapability(c capability.Cap) bool { + caps, err := capability.NewPid2(os.Getpid()) + if err != nil { + return false + } + if err := caps.Load(); err != nil { + return false + } + return caps.Get(capability.EFFECTIVE, c) +} + +func testStat(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + var rootStat linux.Statx + if err := root.StatTo(ctx, &rootStat); err != nil { + t.Errorf("stat on the root dir failed: %v", err) + } + + if ftype := rootStat.Mode & unix.S_IFMT; ftype != unix.S_IFDIR { + t.Errorf("root inode is not a directory, file type = %d", ftype) + } +} + +func testRegularFileIO(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "tempFile" + controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, controlFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + // Test Read/Write RPCs with 2MB of data to test IO in chunks. + data := make([]byte, 1<<21) + rand.Read(data) + if err := writeFD(ctx, t, fd, 0, data); err != nil { + t.Fatalf("write failed: %v", err) + } + readFDAndCmp(ctx, t, fd, 0, data) + readFDAndCmp(ctx, t, fd, 50, data[50:]) + + // Make sure the host FD is configured properly. + hostReadData := make([]byte, len(data)) + if n, err := unix.Pread(hostFD, hostReadData, 0); err != nil { + t.Errorf("host read failed: %v", err) + } else if n != len(hostReadData) { + t.Errorf("partial read: buf size = %d, read = %d", len(hostReadData), n) + } else if bytes.Compare(hostReadData, data) != 0 { + t.Errorf("bytes read differ from what was expected: want = %v, got = %v", data, hostReadData) + } + + // Test syncing the writable FD. + if err := fd.Sync(ctx); err != nil { + t.Errorf("syncing the FD failed: %v", err) + } +} + +func testRegularFileOpen(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "tempFile" + controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, controlFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + // Open a readonly FD and try writing to it to get an EBADF. + roFile, roHostFD := openFile(ctx, t, controlFile, unix.O_RDONLY, true /* isReg */) + defer closeFD(ctx, t, roFile) + defer unix.Close(roHostFD) + if err := writeFD(ctx, t, roFile, 0, []byte{1, 2, 3}); err != unix.EBADF { + t.Errorf("writing to read only FD should generate EBADF, but got %v", err) + } +} + +func testSetStat(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "tempFile" + controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, controlFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + now := time.Now() + wantStat := linux.Statx{ + Mask: unix.STATX_MODE | unix.STATX_ATIME | unix.STATX_MTIME | unix.STATX_SIZE, + Mode: 0760, + UID: uint32(unix.Getuid()), + GID: uint32(unix.Getgid()), + Size: 50, + Atime: linux.NsecToStatxTimestamp(now.UnixNano()), + Mtime: linux.NsecToStatxTimestamp(now.UnixNano()), + } + if tester.SetUserGroupIDSupported() { + wantStat.Mask |= unix.STATX_UID | unix.STATX_GID + } + failureMask, failureErr, err := controlFile.SetStat(ctx, &wantStat) + if err != nil { + t.Fatalf("setstat failed: %v", err) + } + if failureMask != 0 { + t.Fatalf("some setstat operations failed: failureMask = %#b, failureErr = %v", failureMask, failureErr) + } + + // Verify that attributes were updated. + var gotStat linux.Statx + statTo(ctx, t, controlFile, &gotStat) + if gotStat.Mode&07777 != wantStat.Mode || + gotStat.Size != wantStat.Size || + gotStat.Atime.ToNsec() != wantStat.Atime.ToNsec() || + gotStat.Mtime.ToNsec() != wantStat.Mtime.ToNsec() || + (tester.SetUserGroupIDSupported() && (uint32(gotStat.UID) != wantStat.UID || uint32(gotStat.GID) != wantStat.GID)) { + t.Errorf("setStat did not update file correctly: setStat = %+v, stat = %+v", wantStat, gotStat) + } +} + +func testAllocate(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "tempFile" + controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, controlFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + allocateAndVerify(ctx, t, fd, 0, 40) + allocateAndVerify(ctx, t, fd, 20, 100) +} + +func testStatFS(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + var statFS lisafs.StatFS + if err := root.StatFSTo(ctx, &statFS); err != nil { + t.Errorf("statfs failed: %v", err) + } +} + +func testUnlink(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "tempFile" + controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, controlFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + unlinkFile(ctx, t, root, name, false /* isDir */) + if inodes := walk(ctx, t, root, []string{name}); len(inodes) > 0 { + t.Errorf("deleted file should not be generating inodes on walk: %+v", inodes) + } +} + +func testSymlink(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + target := "/tmp/some/path" + name := "symlinkFile" + link, linkStat := symlink(ctx, t, root, name, target) + defer closeFD(ctx, t, link) + + if linkStat.Mode&unix.S_IFMT != unix.S_IFLNK { + t.Errorf("stat return from symlink RPC indicates that the inode is not a symlink: mode = %d", linkStat.Mode) + } + + if gotTarget, err := link.ReadLinkAt(ctx); err != nil { + t.Fatalf("readlink failed: %v", err) + } else if gotTarget != target { + t.Errorf("readlink return incorrect target: expected %q, got %q", target, gotTarget) + } +} + +func testHardLink(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + if !tester.LinkSupported() { + t.Skipf("server does not support LinkAt RPC") + } + if !hasCapability(capability.CAP_DAC_READ_SEARCH) { + t.Skipf("TestHardLink requires CAP_DAC_READ_SEARCH, running as %d", unix.Getuid()) + } + name := "tempFile" + controlFile, fileIno, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, controlFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + link, linkStat := link(ctx, t, root, name, controlFile) + defer closeFD(ctx, t, link) + + if linkStat.Ino != fileIno.Ino { + t.Errorf("hard linked files have different inode numbers: %d %d", linkStat.Ino, fileIno.Ino) + } + if linkStat.DevMinor != fileIno.DevMinor { + t.Errorf("hard linked files have different minor device numbers: %d %d", linkStat.DevMinor, fileIno.DevMinor) + } + if linkStat.DevMajor != fileIno.DevMajor { + t.Errorf("hard linked files have different major device numbers: %d %d", linkStat.DevMajor, fileIno.DevMajor) + } +} + +func testWalk(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + // Create 10 nested directories. + n := 10 + curDir := root + + dirNames := make([]string, 0, n) + for i := 0; i < n; i++ { + name := fmt.Sprintf("tmpdir-%d", i) + childDir, _ := mkdir(ctx, t, curDir, name) + defer closeFD(ctx, t, childDir) + defer unlinkFile(ctx, t, curDir, name, true /* isDir */) + + curDir = childDir + dirNames = append(dirNames, name) + } + + // Walk all these directories. Add some junk at the end which should not be + // walked on. + dirNames = append(dirNames, []string{"a", "b", "c"}...) + inodes := walk(ctx, t, root, dirNames) + if len(inodes) != n { + t.Errorf("walk returned the incorrect number of inodes: wanted %d, got %d", n, len(inodes)) + } + + // Close all control FDs and collect stat results for all dirs including + // the root directory. + dirStats := make([]linux.Statx, 0, n+1) + var stat linux.Statx + statTo(ctx, t, root, &stat) + dirStats = append(dirStats, stat) + for _, inode := range inodes { + dirStats = append(dirStats, inode.Stat) + closeFD(ctx, t, root.Client().NewFD(inode.ControlFD)) + } + + // Test WalkStat which additonally returns Statx for root because the first + // path component is "". + dirNames = append([]string{""}, dirNames...) + gotStats := walkStat(ctx, t, root, dirNames) + if len(gotStats) != len(dirStats) { + t.Errorf("walkStat returned the incorrect number of statx: wanted %d, got %d", len(dirStats), len(gotStats)) + } else { + for i := range gotStats { + cmpStatx(t, dirStats[i], gotStats[i]) + } + } +} + +func testRename(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "tempFile" + tempFile, _, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, tempFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + tempDir, _ := mkdir(ctx, t, root, "tempDir") + defer closeFD(ctx, t, tempDir) + + // Move tempFile into tempDir. + if err := tempFile.RenameTo(ctx, tempDir.ID(), "movedFile"); err != nil { + t.Fatalf("rename failed: %v", err) + } + + inodes := walkStat(ctx, t, root, []string{"tempDir", "movedFile"}) + if len(inodes) != 2 { + t.Errorf("expected 2 files on walk but only found %d", len(inodes)) + } +} + +func testMknod(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "namedPipe" + pipeFile, pipeStat := mknod(ctx, t, root, name) + defer closeFD(ctx, t, pipeFile) + + var stat linux.Statx + statTo(ctx, t, pipeFile, &stat) + + if stat.Mode != pipeStat.Mode { + t.Errorf("mknod mode is incorrect: want %d, got %d", pipeStat.Mode, stat.Mode) + } + if stat.UID != pipeStat.UID { + t.Errorf("mknod UID is incorrect: want %d, got %d", pipeStat.UID, stat.UID) + } + if stat.GID != pipeStat.GID { + t.Errorf("mknod GID is incorrect: want %d, got %d", pipeStat.GID, stat.GID) + } +} + +func testGetdents(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + tempDir, _ := mkdir(ctx, t, root, "tempDir") + defer closeFD(ctx, t, tempDir) + defer unlinkFile(ctx, t, root, "tempDir", true /* isDir */) + + // Create 10 files in tempDir. + n := 10 + fileStats := make(map[string]linux.Statx) + for i := 0; i < n; i++ { + name := fmt.Sprintf("file-%d", i) + newFile, fileStat := mknod(ctx, t, tempDir, name) + defer closeFD(ctx, t, newFile) + defer unlinkFile(ctx, t, tempDir, name, false /* isDir */) + + fileStats[name] = fileStat + } + + // Use opened directory FD for getdents. + openDirFile, _ := openFile(ctx, t, tempDir, unix.O_RDONLY, false /* isReg */) + defer closeFD(ctx, t, openDirFile) + + dirents := make([]lisafs.Dirent64, 0, n) + for i := 0; i < n+2; i++ { + gotDirents, err := openDirFile.Getdents64(ctx, 40) + if err != nil { + t.Fatalf("getdents failed: %v", err) + } + if len(gotDirents) == 0 { + break + } + for _, dirent := range gotDirents { + if dirent.Name != "." && dirent.Name != ".." { + dirents = append(dirents, dirent) + } + } + } + + if len(dirents) != n { + t.Errorf("got incorrect number of dirents: wanted %d, got %d", n, len(dirents)) + } + for _, dirent := range dirents { + stat, ok := fileStats[string(dirent.Name)] + if !ok { + t.Errorf("received a dirent that was not created: %+v", dirent) + continue + } + + if dirent.Type != unix.DT_REG { + t.Errorf("dirent type of %s is incorrect: %d", dirent.Name, dirent.Type) + } + if uint64(dirent.Ino) != stat.Ino { + t.Errorf("dirent ino of %s is incorrect: want %d, got %d", dirent.Name, stat.Ino, dirent.Ino) + } + if uint32(dirent.DevMinor) != stat.DevMinor { + t.Errorf("dirent dev minor of %s is incorrect: want %d, got %d", dirent.Name, stat.DevMinor, dirent.DevMinor) + } + if uint32(dirent.DevMajor) != stat.DevMajor { + t.Errorf("dirent dev major of %s is incorrect: want %d, got %d", dirent.Name, stat.DevMajor, dirent.DevMajor) + } + } +} diff --git a/pkg/p9/client.go b/pkg/p9/client.go index eb496f02f..d618da820 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -115,7 +115,7 @@ type Client struct { // channels is the set of all initialized channels. channels []*channel - // availableChannels is a FIFO of inactive channels. + // availableChannels is a LIFO of inactive channels. availableChannels []*channel // -- below corresponds to sendRecvLegacy -- diff --git a/pkg/ring0/defs.go b/pkg/ring0/defs.go index b6e2012e8..38ce9be1e 100644 --- a/pkg/ring0/defs.go +++ b/pkg/ring0/defs.go @@ -77,6 +77,9 @@ type CPU struct { // calls and exceptions via the Registers function. registers arch.Registers + // floatingPointState holds floating point state. + floatingPointState fpu.State + // hooks are kernel hooks. hooks Hooks } @@ -90,6 +93,15 @@ func (c *CPU) Registers() *arch.Registers { return &c.registers } +// FloatingPointState returns the kernel floating point state. +// +// This is explicitly safe to call during KernelException and KernelSyscall. +// +//go:nosplit +func (c *CPU) FloatingPointState() *fpu.State { + return &c.floatingPointState +} + // SwitchOpts are passed to the Switch function. type SwitchOpts struct { // Registers are the user register state. diff --git a/pkg/ring0/defs_amd64.go b/pkg/ring0/defs_amd64.go index 24f6e4cde..81e90dbf7 100644 --- a/pkg/ring0/defs_amd64.go +++ b/pkg/ring0/defs_amd64.go @@ -116,6 +116,11 @@ type CPUArchState struct { errorType uintptr *kernelEntry + + // Copies of global variables, stored in CPU so that they can be used by + // syscall and exception handlers (in the upper address space). + hasXSAVE bool + hasXSAVEOPT bool } // ErrorCode returns the last error code. diff --git a/pkg/ring0/entry_amd64.go b/pkg/ring0/entry_amd64.go index afd646b0b..13ad4e4df 100644 --- a/pkg/ring0/entry_amd64.go +++ b/pkg/ring0/entry_amd64.go @@ -39,11 +39,6 @@ func sysenter() // assembly to get the ABI0 (i.e., primary) address. func addrOfSysenter() uintptr -// swapgs swaps the current GS value. -// -// This must be called prior to sysret/iret. -func swapgs() - // jumpToKernel jumps to the kernel version of the current RIP. func jumpToKernel() diff --git a/pkg/ring0/entry_amd64.s b/pkg/ring0/entry_amd64.s index 520bd9f57..d2913f190 100644 --- a/pkg/ring0/entry_amd64.s +++ b/pkg/ring0/entry_amd64.s @@ -142,8 +142,103 @@ TEXT ·jumpToUser(SB),NOSPLIT,$0 MOVQ AX, 0(SP) RET +// See kernel_amd64.go. +// +// The 16-byte frame size is for the saved values of MXCSR and the x87 control +// word. +TEXT ·doSwitchToUser(SB),NOSPLIT,$16-48 + // We are passed pointers to heap objects, but do not store them in our + // local frame. + NO_LOCAL_POINTERS + + // MXCSR and the x87 control word are the only floating point state + // that is callee-save and thus we must save. + STMXCSR mxcsr-0(SP) + FSTCW cw-8(SP) + + // Restore application floating point state. + MOVQ cpu+0(FP), SI + MOVQ fpState+16(FP), DI + MOVB ·hasXSAVE(SB), BX + TESTB BX, BX + JZ no_xrstor + // Use xrstor to restore all available fp state. For now, we restore + // everything unconditionally by setting the implicit operand edx:eax + // (the "requested feature bitmap") to all 1's. + MOVL $0xffffffff, AX + MOVL $0xffffffff, DX + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x2f // XRSTOR64 0(DI) + JMP fprestore_done +no_xrstor: + // Fall back to fxrstor if xsave is not available. + FXRSTOR64 0(DI) +fprestore_done: + + // Set application GS. + MOVQ regs+8(FP), R8 + SWAP_GS() + MOVQ PTRACE_GS_BASE(R8), AX + PUSHQ AX + CALL ·writeGS(SB) + POPQ AX + + // Call sysret() or iret(). + MOVQ userCR3+24(FP), CX + MOVQ needIRET+32(FP), R9 + ADDQ $-32, SP + MOVQ SI, 0(SP) // cpu + MOVQ R8, 8(SP) // regs + MOVQ CX, 16(SP) // userCR3 + TESTQ R9, R9 + JNZ do_iret + CALL ·sysret(SB) + JMP done_sysret_or_iret +do_iret: + CALL ·iret(SB) +done_sysret_or_iret: + MOVQ 24(SP), AX // vector + ADDQ $32, SP + MOVQ AX, vector+40(FP) + + // Save application floating point state. + MOVQ fpState+16(FP), DI + MOVB ·hasXSAVE(SB), BX + MOVB ·hasXSAVEOPT(SB), CX + TESTB BX, BX + JZ no_xsave + // Use xsave/xsaveopt to save all extended state. + // We save everything unconditionally by setting RFBM to all 1's. + MOVL $0xffffffff, AX + MOVL $0xffffffff, DX + TESTB CX, CX + JZ no_xsaveopt + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x37; // XSAVEOPT64 0(DI) + JMP fpsave_done +no_xsaveopt: + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27; // XSAVE64 0(DI) + JMP fpsave_done +no_xsave: + FXSAVE64 0(DI) +fpsave_done: + + // Restore MXCSR and the x87 control word after one of the two floating + // point save cases above, to ensure the application versions are saved + // before being clobbered here. + LDMXCSR mxcsr-0(SP) + + // FLDCW is a "waiting" x87 instruction, meaning it checks for pending + // unmasked exceptions before executing. Thus if userspace has unmasked + // an exception and has one pending, it can be raised by FLDCW even + // though the new control word will mask exceptions. To prevent this, + // we must first clear pending exceptions (which will be restored by + // XRSTOR, et al). + BYTE $0xDB; BYTE $0xE2; // FNCLEX + FLDCW cw-8(SP) + + RET + // See entry_amd64.go. -TEXT ·sysret(SB),NOSPLIT,$0-24 +TEXT ·sysret(SB),NOSPLIT,$0-32 // Set application FS. We can't do this in Go because Go code needs FS. MOVQ regs+8(FP), AX MOVQ PTRACE_FS_BASE(AX), AX @@ -182,9 +277,11 @@ TEXT ·sysret(SB),NOSPLIT,$0-24 POPQ AX // Restore AX. POPQ SP // Restore SP. SYSRET64() + // sysenter or exception will write our return value and return to our + // caller. // See entry_amd64.go. -TEXT ·iret(SB),NOSPLIT,$0-24 +TEXT ·iret(SB),NOSPLIT,$0-32 // Set application FS. We can't do this in Go because Go code needs FS. MOVQ regs+8(FP), AX MOVQ PTRACE_FS_BASE(AX), AX @@ -220,6 +317,8 @@ TEXT ·iret(SB),NOSPLIT,$0-24 WRITE_CR3() // Switch to userCR3. POPQ AX // Restore AX. IRET() + // sysenter or exception will write our return value and return to our + // caller. // See entry_amd64.go. TEXT ·resume(SB),NOSPLIT,$0 @@ -324,11 +423,39 @@ kernel: MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code. MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel. + // Save floating point state. CPU.floatingPointState is a slice, so the + // first word of CPU.floatingPointState is a pointer to the destination + // array. + MOVQ CPU_FPU_STATE(AX), DI + MOVB CPU_HAS_XSAVE(AX), BX + MOVB CPU_HAS_XSAVEOPT(AX), CX + TESTB BX, BX + JZ no_xsave + // Use xsave/xsaveopt to save all extended state. + // We save everything unconditionally by setting RFBM to all 1's. + MOVL $0xffffffff, AX + MOVL $0xffffffff, DX + TESTB CX, CX + JZ no_xsaveopt + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x37; // XSAVEOPT64 0(DI) + JMP fpsave_done +no_xsaveopt: + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27; // XSAVE64 0(DI) + JMP fpsave_done +no_xsave: + FXSAVE64 0(DI) +fpsave_done: + // Call the syscall trampoline. LOAD_KERNEL_STACK(GS) - PUSHQ AX // First argument (vCPU). - CALL ·kernelSyscall(SB) // Call the trampoline. - POPQ AX // Pop vCPU. + MOVQ ENTRY_CPU_SELF(GS), AX // AX contains the vCPU. + PUSHQ AX // First argument (vCPU). + CALL ·kernelSyscall(SB) // Call the trampoline. + POPQ AX // Pop vCPU. + + // We only trigger a bluepill entry in the bluepill function, and can + // therefore be guaranteed that there is no floating point state to be + // loaded on resuming from halt. JMP ·resume(SB) ADDR_OF_FUNC(·addrOfSysenter(SB), ·sysenter(SB)); @@ -416,15 +543,43 @@ kernel: MOVQ 8(SP), BX // Load the error code. MOVQ BX, CPU_ERROR_CODE(AX) // Copy out to the CPU. MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel. - MOVQ 0(SP), BX // BX contains the vector. + + // Save floating point state. CPU.floatingPointState is a slice, so the + // first word of CPU.floatingPointState is a pointer to the destination + // array. + MOVQ CPU_FPU_STATE(AX), DI + MOVB CPU_HAS_XSAVE(AX), BX + MOVB CPU_HAS_XSAVEOPT(AX), CX + TESTB BX, BX + JZ no_xsave + // Use xsave/xsaveopt to save all extended state. + // We save everything unconditionally by setting RFBM to all 1's. + MOVL $0xffffffff, AX + MOVL $0xffffffff, DX + TESTB CX, CX + JZ no_xsaveopt + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x37; // XSAVEOPT64 0(DI) + JMP fpsave_done +no_xsaveopt: + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27; // XSAVE64 0(DI) + JMP fpsave_done +no_xsave: + FXSAVE64 0(DI) +fpsave_done: // Call the exception trampoline. + MOVQ 0(SP), BX // BX contains the vector. LOAD_KERNEL_STACK(GS) - PUSHQ BX // Second argument (vector). - PUSHQ AX // First argument (vCPU). - CALL ·kernelException(SB) // Call the trampoline. - POPQ BX // Pop vector. - POPQ AX // Pop vCPU. + MOVQ ENTRY_CPU_SELF(GS), AX // AX contains the vCPU. + PUSHQ BX // Second argument (vector). + PUSHQ AX // First argument (vCPU). + CALL ·kernelException(SB) // Call the trampoline. + POPQ BX // Pop vector. + POPQ AX // Pop vCPU. + + // We only trigger a bluepill entry in the bluepill function, and can + // therefore be guaranteed that there is no floating point state to be + // loaded on resuming from halt. JMP ·resume(SB) #define EXCEPTION_WITH_ERROR(value, symbol, addr) \ diff --git a/pkg/ring0/kernel.go b/pkg/ring0/kernel.go index 292f9d0cc..e7dd84929 100644 --- a/pkg/ring0/kernel.go +++ b/pkg/ring0/kernel.go @@ -14,6 +14,10 @@ package ring0 +import ( + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" +) + // Init initializes a new kernel. // //go:nosplit @@ -80,6 +84,7 @@ func (c *CPU) Init(k *Kernel, cpuID int, hooks Hooks) { c.self = c // Set self reference. c.kernel = k // Set kernel reference. c.init(cpuID) // Perform architectural init. + c.floatingPointState = fpu.NewState() // Require hooks. if hooks != nil { diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go index 4a4c0ae26..7e55011b5 100644 --- a/pkg/ring0/kernel_amd64.go +++ b/pkg/ring0/kernel_amd64.go @@ -143,6 +143,9 @@ func (c *CPU) init(cpuID int) { // Set mandatory flags. c.registers.Eflags = KernelFlagsSet + + c.hasXSAVE = hasXSAVE + c.hasXSAVEOPT = hasXSAVEOPT } // StackTop returns the kernel's stack address. @@ -248,19 +251,21 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { regs.Ss = uint64(Udata) // Ditto. // Perform the switch. - swapgs() // GS will be swapped on return. - WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS. - LoadFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy in floating point. + needIRET := uint64(0) if switchOpts.FullRestore { - vector = iret(c, regs, uintptr(userCR3)) - } else { - vector = sysret(c, regs, uintptr(userCR3)) + needIRET = 1 } - SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy out floating point. - RestoreKernelFPState() // escapes: no. Restore kernel MXCSR. + vector = doSwitchToUser(c, regs, switchOpts.FloatingPointState.BytePointer(), userCR3, needIRET) // escapes: no. return } +func doSwitchToUser( + cpu *CPU, // +0(FP) + regs *arch.Registers, // +8(FP) + fpState *byte, // +16(FP) + userCR3 uint64, // +24(FP) + needIRET uint64) Vector // +32(FP), +40(FP) + var ( sentryXCR0 uintptr sentryXCR0Once sync.Once @@ -287,7 +292,7 @@ func initSentryXCR0() { //go:nosplit func startGo(c *CPU) { // Save per-cpu. - WriteGS(kernelAddr(c.kernelEntry)) + writeGS(kernelAddr(c.kernelEntry)) // // TODO(mpratt): Note that per the note above, this should be done diff --git a/pkg/ring0/lib_amd64.go b/pkg/ring0/lib_amd64.go index 05c394ff5..c42a5b205 100644 --- a/pkg/ring0/lib_amd64.go +++ b/pkg/ring0/lib_amd64.go @@ -21,29 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" ) -// LoadFloatingPoint loads floating point state by the most efficient mechanism -// available (set by Init). -var LoadFloatingPoint func(*byte) - -// SaveFloatingPoint saves floating point state by the most efficient mechanism -// available (set by Init). -var SaveFloatingPoint func(*byte) - -// fxrstor uses fxrstor64 to load floating point state. -func fxrstor(*byte) - -// xrstor uses xrstor to load floating point state. -func xrstor(*byte) - -// fxsave uses fxsave64 to save floating point state. -func fxsave(*byte) - -// xsave uses xsave to save floating point state. -func xsave(*byte) - -// xsaveopt uses xsaveopt to save floating point state. -func xsaveopt(*byte) - // writeFS sets the FS base address (selects one of wrfsbase or wrfsmsr). func writeFS(addr uintptr) @@ -53,8 +30,8 @@ func wrfsbase(addr uintptr) // wrfsmsr writes to the GS_BASE MSR. func wrfsmsr(addr uintptr) -// WriteGS sets the GS address (set by init). -var WriteGS func(addr uintptr) +// writeGS sets the GS address (selects one of wrgsbase or wrgsmsr). +func writeGS(addr uintptr) // wrgsbase writes to the GS base address. func wrgsbase(addr uintptr) @@ -106,19 +83,4 @@ func Init(featureSet *cpuid.FeatureSet) { hasXSAVE = featureSet.UseXsave() hasFSGSBASE = featureSet.HasFeature(cpuid.X86FeatureFSGSBase) validXCR0Mask = uintptr(featureSet.ValidXCR0Mask()) - if hasXSAVEOPT { - SaveFloatingPoint = xsaveopt - LoadFloatingPoint = xrstor - } else if hasXSAVE { - SaveFloatingPoint = xsave - LoadFloatingPoint = xrstor - } else { - SaveFloatingPoint = fxsave - LoadFloatingPoint = fxrstor - } - if hasFSGSBASE { - WriteGS = wrgsbase - } else { - WriteGS = wrgsmsr - } } diff --git a/pkg/ring0/lib_amd64.s b/pkg/ring0/lib_amd64.s index 8ed98fc84..0f283aaae 100644 --- a/pkg/ring0/lib_amd64.s +++ b/pkg/ring0/lib_amd64.s @@ -128,6 +128,29 @@ TEXT ·wrfsmsr(SB),NOSPLIT,$0-8 BYTE $0x0f; BYTE $0x30; RET +// writeGS writes to the GS base. +// +// This is written in assembly because it must be callable from assembly (ABI0) +// without an intermediate transition to ABIInternal. +// +// Preconditions: must be running in the lower address space, as it accesses +// global data. +TEXT ·writeGS(SB),NOSPLIT,$8-8 + MOVQ addr+0(FP), AX + + CMPB ·hasFSGSBASE(SB), $1 + JNE msr + + PUSHQ AX + CALL ·wrgsbase(SB) + POPQ AX + RET +msr: + PUSHQ AX + CALL ·wrgsmsr(SB) + POPQ AX + RET + // wrgsbase writes to the GS base. // // The code corresponds to: diff --git a/pkg/ring0/offsets_amd64.go b/pkg/ring0/offsets_amd64.go index 75f6218b3..38fe27c35 100644 --- a/pkg/ring0/offsets_amd64.go +++ b/pkg/ring0/offsets_amd64.go @@ -35,6 +35,9 @@ func Emit(w io.Writer) { fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_ENTRY 0x%02x\n", reflect.ValueOf(&c.kernelEntry).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_HAS_XSAVE 0x%02x\n", reflect.ValueOf(&c.hasXSAVE).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_HAS_XSAVEOPT 0x%02x\n", reflect.ValueOf(&c.hasXSAVEOPT).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_FPU_STATE 0x%02x\n", reflect.ValueOf(&c.floatingPointState).Pointer()-reflect.ValueOf(c).Pointer()) e := &kernelEntry{} fmt.Fprintf(w, "\n// CPU entry offsets.\n") diff --git a/pkg/ring0/pagetables/pagetables.go b/pkg/ring0/pagetables/pagetables.go index 9dac53c80..3f17fba49 100644 --- a/pkg/ring0/pagetables/pagetables.go +++ b/pkg/ring0/pagetables/pagetables.go @@ -322,12 +322,3 @@ func (p *PageTables) Lookup(addr hostarch.Addr, findFirst bool) (virtual hostarc func (p *PageTables) MarkReadOnlyShared() { p.readOnlyShared = true } - -// PrefaultRootTable touches the root table page to be sure that its physical -// pages are mapped. -// -//go:nosplit -//go:noinline -func (p *PageTables) PrefaultRootTable() PTE { - return p.root[0] -} diff --git a/pkg/safecopy/BUILD b/pkg/safecopy/BUILD index 0a045fc8e..2a1602e2b 100644 --- a/pkg/safecopy/BUILD +++ b/pkg/safecopy/BUILD @@ -18,9 +18,9 @@ go_library( ], visibility = ["//:sandbox"], deps = [ - "//pkg/abi/linux", "//pkg/errors", "//pkg/errors/linuxerr", + "//pkg/sighandling", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/safecopy/safecopy.go b/pkg/safecopy/safecopy.go index a9711e63d..0dd0aea83 100644 --- a/pkg/safecopy/safecopy.go +++ b/pkg/safecopy/safecopy.go @@ -23,6 +23,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/errors" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/sighandling" ) // SegvError is returned when a safecopy function receives SIGSEGV. @@ -132,10 +133,10 @@ func initializeAddresses() { func init() { initializeAddresses() - if err := ReplaceSignalHandler(unix.SIGSEGV, addrOfSignalHandler(), &savedSigSegVHandler); err != nil { + if err := sighandling.ReplaceSignalHandler(unix.SIGSEGV, addrOfSignalHandler(), &savedSigSegVHandler); err != nil { panic(fmt.Sprintf("Unable to set handler for SIGSEGV: %v", err)) } - if err := ReplaceSignalHandler(unix.SIGBUS, addrOfSignalHandler(), &savedSigBusHandler); err != nil { + if err := sighandling.ReplaceSignalHandler(unix.SIGBUS, addrOfSignalHandler(), &savedSigBusHandler); err != nil { panic(fmt.Sprintf("Unable to set handler for SIGBUS: %v", err)) } linuxerr.AddErrorUnwrapper(func(e error) (*errors.Error, bool) { diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go index 2365b2c0d..15f84abea 100644 --- a/pkg/safecopy/safecopy_unsafe.go +++ b/pkg/safecopy/safecopy_unsafe.go @@ -20,7 +20,6 @@ import ( "unsafe" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/abi/linux" ) // maxRegisterSize is the maximum register size used in memcpy and memclr. It @@ -332,39 +331,3 @@ func errorFromFaultSignal(addr uintptr, sig int32) error { panic(fmt.Sprintf("safecopy got unexpected signal %d at address %#x", sig, addr)) } } - -// ReplaceSignalHandler replaces the existing signal handler for the provided -// signal with the one that handles faults in safecopy-protected functions. -// -// It stores the value of the previously set handler in previous. -// -// This function will be called on initialization in order to install safecopy -// handlers for appropriate signals. These handlers will call the previous -// handler however, and if this is function is being used externally then the -// same courtesy is expected. -func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) error { - var sa linux.SigAction - const maskLen = 8 - - // Get the existing signal handler information, and save the current - // handler. Once we replace it, we will use this pointer to fall back to - // it when we receive other signals. - if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 { - return e - } - - // Fail if there isn't a previous handler. - if sa.Handler == 0 { - return fmt.Errorf("previous handler for signal %x isn't set", sig) - } - - *previous = uintptr(sa.Handler) - - // Install our own handler. - sa.Handler = uint64(handler) - if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 { - return e - } - - return nil -} diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index 2f3664c57..f721b7236 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -26,6 +26,23 @@ import ( "gvisor.dev/gvisor/pkg/urpc" ) +const ( + // DefaultBlockProfileRate is the default profiling rate for block + // profiles. + // + // The default here is 10%, which will record a stacktrace 10% of the + // time when blocking occurs. Since these events should not be super + // frequent, we expect this to achieve a reasonable balance between + // collecting the data we need and imposing a high performance cost + // (e.g. skewing even the CPU profile). + DefaultBlockProfileRate = 10 + + // DefaultMutexProfileRate is the default profiling rate for mutex + // profiles. Like the block rate above, we use a default rate of 10% + // for the same reasons. + DefaultMutexProfileRate = 10 +) + // Profile includes profile-related RPC stubs. It provides a way to // control the built-in runtime profiling facilities. // @@ -175,12 +192,8 @@ func (p *Profile) Block(o *BlockProfileOpts, _ *struct{}) error { defer p.blockMu.Unlock() // Always set the rate. We then wait to collect a profile at this rate, - // and disable when we're done. Note that the default here is 10%, which - // will record a stacktrace 10% of the time when blocking occurs. Since - // these events should not be super frequent, we expect this to achieve - // a reasonable balance between collecting the data we need and imposing - // a high performance cost (e.g. skewing even the CPU profile). - rate := 10 + // and disable when we're done. + rate := DefaultBlockProfileRate if o.Rate != 0 { rate = o.Rate } @@ -220,9 +233,8 @@ func (p *Profile) Mutex(o *MutexProfileOpts, _ *struct{}) error { p.mutexMu.Lock() defer p.mutexMu.Unlock() - // Always set the fraction. Like the block rate above, we use - // a default rate of 10% for the same reasons. - fraction := 10 + // Always set the fraction. + fraction := DefaultMutexProfileRate if o.Fraction != 0 { fraction = o.Fraction } diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go index 4370cce33..d2eb03bb7 100644 --- a/pkg/sentry/fs/fdpipe/pipe.go +++ b/pkg/sentry/fs/fdpipe/pipe.go @@ -45,7 +45,8 @@ type pipeOperations struct { fsutil.FileNoIoctl `state:"nosave"` fsutil.FileNoSplice `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` - waiter.Queue `state:"nosave"` + + waiter.Queue // flags are the flags used to open the pipe. flags fs.FileFlags `state:".(fs.FileFlags)"` diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go index 031cd33ce..a27dd0b9a 100644 --- a/pkg/sentry/fs/file_overlay.go +++ b/pkg/sentry/fs/file_overlay.go @@ -16,6 +16,7 @@ package fs import ( "io" + "math" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" @@ -360,10 +361,13 @@ func (*overlayFileOperations) ConfigureMMap(ctx context.Context, file *File, opt return linuxerr.ENODEV } - // FIXME(jamieliu): This is a copy/paste of fsutil.GenericConfigureMMap, - // which we can't use because the overlay implementation is in package fs, - // so depending on fs/fsutil would create a circular dependency. Move - // overlay to fs/overlay. + // TODO(gvisor.dev/issue/1624): This is a copy/paste of + // fsutil.GenericConfigureMMap, which we can't use because the overlay + // implementation is in package fs, so depending on fs/fsutil would create + // a circular dependency. VFS2 overlay doesn't have this issue. + if opts.Offset+opts.Length > math.MaxInt64 { + return linuxerr.EOVERFLOW + } opts.Mappable = o opts.MappingIdentity = file file.IncRef() diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go index 3ece73b81..38e3ed42d 100644 --- a/pkg/sentry/fs/fsutil/file.go +++ b/pkg/sentry/fs/fsutil/file.go @@ -16,6 +16,7 @@ package fsutil import ( "io" + "math" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" @@ -210,6 +211,9 @@ func (FileNoMMap) ConfigureMMap(context.Context, *fs.File, *memmap.MMapOpts) err // GenericConfigureMMap implements fs.FileOperations.ConfigureMMap for most // filesystems that support memory mapping. func GenericConfigureMMap(file *fs.File, m memmap.Mappable, opts *memmap.MMapOpts) error { + if opts.Offset+opts.Length > math.MaxInt64 { + return linuxerr.EOVERFLOW + } opts.Mappable = m opts.MappingIdentity = file file.IncRef() diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index 23528bf25..37ddb1a3c 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -93,7 +93,8 @@ func NewHostFileMapper() *HostFileMapper { func (f *HostFileMapper) IncRefOn(mr memmap.MappableRange) { f.refsMu.Lock() defer f.refsMu.Unlock() - for chunkStart := mr.Start &^ chunkMask; chunkStart < mr.End; chunkStart += chunkSize { + chunkStart := mr.Start &^ chunkMask + for { refs := f.refs[chunkStart] pgs := pagesInChunk(mr, chunkStart) if refs+pgs < refs { @@ -101,6 +102,10 @@ func (f *HostFileMapper) IncRefOn(mr memmap.MappableRange) { panic(fmt.Sprintf("HostFileMapper.IncRefOn(%v): adding %d page references to chunk %#x, which has %d page references", mr, pgs, chunkStart, refs)) } f.refs[chunkStart] = refs + pgs + chunkStart += chunkSize + if chunkStart >= mr.End || chunkStart == 0 { + break + } } } @@ -112,7 +117,8 @@ func (f *HostFileMapper) IncRefOn(mr memmap.MappableRange) { func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) { f.refsMu.Lock() defer f.refsMu.Unlock() - for chunkStart := mr.Start &^ chunkMask; chunkStart < mr.End; chunkStart += chunkSize { + chunkStart := mr.Start &^ chunkMask + for { refs := f.refs[chunkStart] pgs := pagesInChunk(mr, chunkStart) switch { @@ -128,6 +134,10 @@ func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) { case refs < pgs: panic(fmt.Sprintf("HostFileMapper.DecRefOn(%v): removing %d page references from chunk %#x, which has %d page references", mr, pgs, chunkStart, refs)) } + chunkStart += chunkSize + if chunkStart >= mr.End || chunkStart == 0 { + break + } } } @@ -161,7 +171,8 @@ func (f *HostFileMapper) forEachMappingBlockLocked(fr memmap.FileRange, fd int, if write { prot |= unix.PROT_WRITE } - for chunkStart := fr.Start &^ chunkMask; chunkStart < fr.End; chunkStart += chunkSize { + chunkStart := fr.Start &^ chunkMask + for { m, ok := f.mappings[chunkStart] if !ok { addr, _, errno := unix.Syscall6( @@ -201,6 +212,10 @@ func (f *HostFileMapper) forEachMappingBlockLocked(fr memmap.FileRange, fd int, endOff = fr.End - chunkStart } fn(f.unsafeBlockFromChunkMapping(m.addr).TakeFirst64(endOff).DropFirst64(startOff)) + chunkStart += chunkSize + if chunkStart >= fr.End || chunkStart == 0 { + break + } } return nil } diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index 92d58e3e9..99c37291e 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -70,7 +70,7 @@ type inodeFileState struct { descriptor *descriptor `state:"wait"` // Event queue for blocking operations. - queue waiter.Queue `state:"zerovalue"` + queue waiter.Queue // sattr is used to restore the inodeOperations. sattr fs.StableAttr `state:"wait"` diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index 51cd6cd37..941f37116 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -43,7 +43,7 @@ type Inotify struct { // user, since we may aggressively reuse an id on S/R. id uint64 - waiter.Queue `state:"nosave"` + waiter.Queue // evMu *only* protects the events list. We need a separate lock because // while queuing events, a watch needs to lock the event queue, and using mu diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go index 7d7a207cc..e39d340fe 100644 --- a/pkg/sentry/fs/lock/lock.go +++ b/pkg/sentry/fs/lock/lock.go @@ -132,7 +132,7 @@ type Locks struct { locks LockSet // blockedQueue is the queue of waiters that are waiting on a lock. - blockedQueue waiter.Queue `state:"zerovalue"` + blockedQueue waiter.Queue } // Blocker is the interface used for blocking locks. Passing a nil Blocker diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go index 085aa6d61..443b9a94c 100644 --- a/pkg/sentry/fs/proc/sys.go +++ b/pkg/sentry/fs/proc/sys.go @@ -109,6 +109,9 @@ func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode "shmall": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMALL, 10))), "shmmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMAX, 10))), "shmmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMNI, 10))), + "msgmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.MSGMNI, 10))), + "msgmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.MSGMAX, 10))), + "msgmnb": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.MSGMNB, 10))), } d := ramfs.NewDir(ctx, children, fs.RootOwner, fs.FilePermsFromMode(0555)) diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go index 1c8518d71..ca8be8683 100644 --- a/pkg/sentry/fs/timerfd/timerfd.go +++ b/pkg/sentry/fs/timerfd/timerfd.go @@ -43,7 +43,7 @@ type TimerOperations struct { fsutil.FileNoopFlush `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` - events waiter.Queue `state:"zerovalue"` + events waiter.Queue timer *ktime.Timer // val is the number of timer expirations since the last successful call to diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go index f9fca6d8e..f2c9e9668 100644 --- a/pkg/sentry/fs/tty/line_discipline.go +++ b/pkg/sentry/fs/tty/line_discipline.go @@ -102,10 +102,10 @@ type lineDiscipline struct { column int // masterWaiter is used to wait on the master end of the TTY. - masterWaiter waiter.Queue `state:"zerovalue"` + masterWaiter waiter.Queue // replicaWaiter is used to wait on the replica end of the TTY. - replicaWaiter waiter.Queue `state:"zerovalue"` + replicaWaiter waiter.Queue } func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline { diff --git a/pkg/sentry/fsimpl/cgroupfs/BUILD b/pkg/sentry/fsimpl/cgroupfs/BUILD index e5fdcc776..60ee5ede2 100644 --- a/pkg/sentry/fsimpl/cgroupfs/BUILD +++ b/pkg/sentry/fsimpl/cgroupfs/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") licenses(["notice"]) @@ -18,6 +18,7 @@ go_library( name = "cgroupfs", srcs = [ "base.go", + "bitmap.go", "cgroupfs.go", "cpu.go", "cpuacct.go", @@ -29,10 +30,12 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/bitmap", "//pkg/context", "//pkg/coverage", "//pkg/errors/linuxerr", "//pkg/fspath", + "//pkg/hostarch", "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", @@ -47,3 +50,11 @@ go_library( "//pkg/usermem", ], ) + +go_test( + name = "cgroupfs_test", + size = "small", + srcs = ["bitmap_test.go"], + library = ":cgroupfs", + deps = ["//pkg/bitmap"], +) diff --git a/pkg/sentry/fsimpl/cgroupfs/bitmap.go b/pkg/sentry/fsimpl/cgroupfs/bitmap.go new file mode 100644 index 000000000..8074641db --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/bitmap.go @@ -0,0 +1,139 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cgroupfs + +import ( + "fmt" + "strconv" + "strings" + + "gvisor.dev/gvisor/pkg/bitmap" +) + +// formatBitmap produces a string representation of b, which lists the indicies +// of set bits in the bitmap. Indicies are separated by commas and ranges of +// set bits are abbreviated. Example outputs: "0,2,4", "0,3-7,10", "0-10". +// +// Inverse of parseBitmap. +func formatBitmap(b *bitmap.Bitmap) string { + ones := b.ToSlice() + if len(ones) == 0 { + return "" + } + + elems := make([]string, 0, len(ones)) + runStart := ones[0] + lastVal := ones[0] + inRun := false + + for _, v := range ones[1:] { + last := lastVal + lastVal = v + + if last+1 == v { + // In a contiguous block of ones. + if !inRun { + runStart = last + inRun = true + } + + continue + } + + // Non-contiguous bit. + if inRun { + // Render a run + elems = append(elems, fmt.Sprintf("%d-%d", runStart, last)) + inRun = false + continue + } + + // Lone non-contiguous bit. + elems = append(elems, fmt.Sprintf("%d", last)) + + } + + // Process potential final run + if inRun { + elems = append(elems, fmt.Sprintf("%d-%d", runStart, lastVal)) + } else { + elems = append(elems, fmt.Sprintf("%d", lastVal)) + } + + return strings.Join(elems, ",") +} + +func parseToken(token string) (start, end uint32, err error) { + ts := strings.SplitN(token, "-", 2) + switch len(ts) { + case 0: + return 0, 0, fmt.Errorf("invalid token %q", token) + case 1: + val, err := strconv.ParseUint(ts[0], 10, 32) + if err != nil { + return 0, 0, err + } + return uint32(val), uint32(val), nil + case 2: + val1, err := strconv.ParseUint(ts[0], 10, 32) + if err != nil { + return 0, 0, err + } + val2, err := strconv.ParseUint(ts[1], 10, 32) + if err != nil { + return 0, 0, err + } + if val1 >= val2 { + return 0, 0, fmt.Errorf("start (%v) must be less than end (%v)", val1, val2) + } + return uint32(val1), uint32(val2), nil + default: + panic(fmt.Sprintf("Unreachable: got %d substrs", len(ts))) + } +} + +// parseBitmap parses input as a bitmap. input should be a comma separated list +// of indices, and ranges of set bits may be abbreviated. Examples: "0,2,4", +// "0,3-7,10", "0-10". Input after the first newline or null byte is discarded. +// +// sizeHint sets the initial size of the bitmap, which may prevent reallocation +// when growing the bitmap during parsing. Ideally sizeHint should be at least +// as large as the bitmap represented by input, but this is not required. +// +// Inverse of formatBitmap. +func parseBitmap(input string, sizeHint uint32) (*bitmap.Bitmap, error) { + b := bitmap.New(sizeHint) + + if termIdx := strings.IndexAny(input, "\n\000"); termIdx != -1 { + input = input[:termIdx] + } + input = strings.TrimSpace(input) + + if len(input) == 0 { + return &b, nil + } + tokens := strings.Split(input, ",") + + for _, t := range tokens { + start, end, err := parseToken(strings.TrimSpace(t)) + if err != nil { + return nil, err + } + for i := start; i <= end; i++ { + b.Add(i) + } + } + return &b, nil +} diff --git a/pkg/sentry/fsimpl/cgroupfs/bitmap_test.go b/pkg/sentry/fsimpl/cgroupfs/bitmap_test.go new file mode 100644 index 000000000..5cc56de3b --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/bitmap_test.go @@ -0,0 +1,99 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cgroupfs + +import ( + "fmt" + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/bitmap" +) + +func TestFormat(t *testing.T) { + tests := []struct { + input []uint32 + output string + }{ + {[]uint32{1, 2, 3, 4, 7}, "1-4,7"}, + {[]uint32{2}, "2"}, + {[]uint32{0, 1, 2}, "0-2"}, + {[]uint32{}, ""}, + {[]uint32{1, 3, 4, 5, 6, 9, 11, 13, 14, 15, 16, 17}, "1,3-6,9,11,13-17"}, + {[]uint32{2, 3, 10, 12, 13, 14, 15, 16, 20, 21, 33, 34, 47}, "2-3,10,12-16,20-21,33-34,47"}, + } + for i, tt := range tests { + t.Run(fmt.Sprintf("case-%d", i), func(t *testing.T) { + b := bitmap.New(64) + for _, v := range tt.input { + b.Add(v) + } + s := formatBitmap(&b) + if s != tt.output { + t.Errorf("Expected %q, got %q", tt.output, s) + } + b1, err := parseBitmap(s, 64) + if err != nil { + t.Fatalf("Failed to parse formatted bitmap: %v", err) + } + if got, want := b1.ToSlice(), b.ToSlice(); !reflect.DeepEqual(got, want) { + t.Errorf("Parsing formatted output doesn't result in the original bitmap. Got %v, want %v", got, want) + } + }) + } +} + +func TestParse(t *testing.T) { + tests := []struct { + input string + output []uint32 + shouldFail bool + }{ + {"1", []uint32{1}, false}, + {"", []uint32{}, false}, + {"1,2,3,4", []uint32{1, 2, 3, 4}, false}, + {"1-4", []uint32{1, 2, 3, 4}, false}, + {"1,2-4", []uint32{1, 2, 3, 4}, false}, + {"1,2-3,4", []uint32{1, 2, 3, 4}, false}, + {"1-2,3,4,10,11", []uint32{1, 2, 3, 4, 10, 11}, false}, + {"1,2-4,5,16", []uint32{1, 2, 3, 4, 5, 16}, false}, + {"abc", []uint32{}, true}, + {"1,3-2,4", []uint32{}, true}, + {"1,3-3,4", []uint32{}, true}, + {"1,2,3\000,4", []uint32{1, 2, 3}, false}, + {"1,2,3\n,4", []uint32{1, 2, 3}, false}, + } + for i, tt := range tests { + t.Run(fmt.Sprintf("case-%d", i), func(t *testing.T) { + b, err := parseBitmap(tt.input, 64) + if tt.shouldFail { + if err == nil { + t.Fatalf("Expected parsing of %q to fail, but it didn't", tt.input) + } + return + } + if err != nil { + t.Fatalf("Failed to parse bitmap: %v", err) + return + } + + got := b.ToSlice() + if !reflect.DeepEqual(got, tt.output) { + t.Errorf("Parsed bitmap doesn't match what we expected. Got %v, want %v", got, tt.output) + } + + }) + } +} diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go index edc3b50b9..e089b2c28 100644 --- a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go +++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go @@ -269,7 +269,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt case controllerCPUAcct: c = newCPUAcctController(fs) case controllerCPUSet: - c = newCPUSetController(fs) + c = newCPUSetController(k, fs) case controllerJob: c = newJobController(fs) case controllerMemory: diff --git a/pkg/sentry/fsimpl/cgroupfs/cpuset.go b/pkg/sentry/fsimpl/cgroupfs/cpuset.go index ac547f8e2..62e7029da 100644 --- a/pkg/sentry/fsimpl/cgroupfs/cpuset.go +++ b/pkg/sentry/fsimpl/cgroupfs/cpuset.go @@ -15,25 +15,133 @@ package cgroupfs import ( + "bytes" + "fmt" + + "gvisor.dev/gvisor/pkg/bitmap" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/usermem" ) // +stateify savable type cpusetController struct { controllerCommon + + maxCpus uint32 + maxMems uint32 + + cpus *bitmap.Bitmap + mems *bitmap.Bitmap } var _ controller = (*cpusetController)(nil) -func newCPUSetController(fs *filesystem) *cpusetController { - c := &cpusetController{} +func newCPUSetController(k *kernel.Kernel, fs *filesystem) *cpusetController { + cores := uint32(k.ApplicationCores()) + cpus := bitmap.New(cores) + cpus.FlipRange(0, cores) + mems := bitmap.New(1) + mems.FlipRange(0, 1) + c := &cpusetController{ + cpus: &cpus, + mems: &mems, + maxCpus: uint32(k.ApplicationCores()), + maxMems: 1, // We always report a single NUMA node. + } c.controllerCommon.init(controllerCPUSet, fs) return c } // AddControlFiles implements controller.AddControlFiles. func (c *cpusetController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) { - // This controller is currently intentionally empty. + contents["cpuset.cpus"] = c.fs.newControllerWritableFile(ctx, creds, &cpusData{c: c}) + contents["cpuset.mems"] = c.fs.newControllerWritableFile(ctx, creds, &memsData{c: c}) +} + +// +stateify savable +type cpusData struct { + c *cpusetController +} + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *cpusData) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "%s\n", formatBitmap(d.c.cpus)) + return nil +} + +// Write implements vfs.WritableDynamicBytesSource.Write. +func (d *cpusData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + src = src.DropFirst64(offset) + if src.NumBytes() > hostarch.PageSize { + return 0, linuxerr.EINVAL + } + + t := kernel.TaskFromContext(ctx) + buf := t.CopyScratchBuffer(hostarch.PageSize) + n, err := src.CopyIn(ctx, buf) + if err != nil { + return 0, err + } + buf = buf[:n] + + b, err := parseBitmap(string(buf), d.c.maxCpus) + if err != nil { + log.Warningf("cgroupfs cpuset controller: Failed to parse bitmap: %v", err) + return 0, linuxerr.EINVAL + } + + if got, want := b.Maximum(), d.c.maxCpus; got > want { + log.Warningf("cgroupfs cpuset controller: Attempted to specify cpuset.cpus beyond highest available cpu: got %d, want %d", got, want) + return 0, linuxerr.EINVAL + } + + d.c.cpus = b + return int64(n), nil +} + +// +stateify savable +type memsData struct { + c *cpusetController +} + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *memsData) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "%s\n", formatBitmap(d.c.mems)) + return nil +} + +// Write implements vfs.WritableDynamicBytesSource.Write. +func (d *memsData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + src = src.DropFirst64(offset) + if src.NumBytes() > hostarch.PageSize { + return 0, linuxerr.EINVAL + } + + t := kernel.TaskFromContext(ctx) + buf := t.CopyScratchBuffer(hostarch.PageSize) + n, err := src.CopyIn(ctx, buf) + if err != nil { + return 0, err + } + buf = buf[:n] + + b, err := parseBitmap(string(buf), d.c.maxMems) + if err != nil { + log.Warningf("cgroupfs cpuset controller: Failed to parse bitmap: %v", err) + return 0, linuxerr.EINVAL + } + + if got, want := b.Maximum(), d.c.maxMems; got > want { + log.Warningf("cgroupfs cpuset controller: Attempted to specify cpuset.mems beyond highest available node: got %d, want %d", got, want) + return 0, linuxerr.EINVAL + } + + d.c.mems = b + return int64(n), nil } diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index 4244f2cf5..509dd0e1a 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -54,7 +54,10 @@ go_library( "//pkg/fdnotifier", "//pkg/fspath", "//pkg/hostarch", + "//pkg/lisafs", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/metric", "//pkg/p9", "//pkg/refs", diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 5c48a9fee..d99a6112c 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -222,47 +222,88 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { off := uint64(0) const count = 64 * 1024 // for consistency with the vfs1 client d.handleMu.RLock() - if d.readFile.isNil() { + if !d.isReadFileOk() { // This should not be possible because a readable handle should // have been opened when the calling directoryFD was opened. d.handleMu.RUnlock() panic("gofer.dentry.getDirents called without a readable handle") } + // shouldSeek0 indicates whether the server should SEEK to 0 before reading + // directory entries. + shouldSeek0 := true for { - p9ds, err := d.readFile.readdir(ctx, off, count) - if err != nil { - d.handleMu.RUnlock() - return nil, err - } - if len(p9ds) == 0 { - d.handleMu.RUnlock() - break - } - for _, p9d := range p9ds { - if p9d.Name == "." || p9d.Name == ".." { - continue + if d.fs.opts.lisaEnabled { + countLisa := int32(count) + if shouldSeek0 { + // See lisafs.Getdents64Req.Count. + countLisa = -countLisa + shouldSeek0 = false + } + lisafsDs, err := d.readFDLisa.Getdents64(ctx, countLisa) + if err != nil { + d.handleMu.RUnlock() + return nil, err + } + if len(lisafsDs) == 0 { + d.handleMu.RUnlock() + break + } + for i := range lisafsDs { + name := string(lisafsDs[i].Name) + if name == "." || name == ".." { + continue + } + dirent := vfs.Dirent{ + Name: name, + Ino: d.fs.inoFromKey(inoKey{ + ino: uint64(lisafsDs[i].Ino), + devMinor: uint32(lisafsDs[i].DevMinor), + devMajor: uint32(lisafsDs[i].DevMajor), + }), + NextOff: int64(len(dirents) + 1), + Type: uint8(lisafsDs[i].Type), + } + dirents = append(dirents, dirent) + if realChildren != nil { + realChildren[name] = struct{}{} + } } - dirent := vfs.Dirent{ - Name: p9d.Name, - Ino: d.fs.inoFromQIDPath(p9d.QID.Path), - NextOff: int64(len(dirents) + 1), + } else { + p9ds, err := d.readFile.readdir(ctx, off, count) + if err != nil { + d.handleMu.RUnlock() + return nil, err } - // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or - // DMSOCKET. - switch p9d.Type { - case p9.TypeSymlink: - dirent.Type = linux.DT_LNK - case p9.TypeDir: - dirent.Type = linux.DT_DIR - default: - dirent.Type = linux.DT_REG + if len(p9ds) == 0 { + d.handleMu.RUnlock() + break } - dirents = append(dirents, dirent) - if realChildren != nil { - realChildren[p9d.Name] = struct{}{} + for _, p9d := range p9ds { + if p9d.Name == "." || p9d.Name == ".." { + continue + } + dirent := vfs.Dirent{ + Name: p9d.Name, + Ino: d.fs.inoFromQIDPath(p9d.QID.Path), + NextOff: int64(len(dirents) + 1), + } + // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or + // DMSOCKET. + switch p9d.Type { + case p9.TypeSymlink: + dirent.Type = linux.DT_LNK + case p9.TypeDir: + dirent.Type = linux.DT_DIR + default: + dirent.Type = linux.DT_REG + } + dirents = append(dirents, dirent) + if realChildren != nil { + realChildren[p9d.Name] = struct{}{} + } } + off = p9ds[len(p9ds)-1].Offset } - off = p9ds[len(p9ds)-1].Offset } } // Emit entries for synthetic children. diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 00228c469..23c8b8ce3 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -21,10 +21,12 @@ import ( "sync" "sync/atomic" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/lisafs" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/fsimpl/host" "gvisor.dev/gvisor/pkg/sentry/fsmetric" @@ -53,9 +55,47 @@ func (fs *filesystem) Sync(ctx context.Context) error { // regardless. var retErr error + if fs.opts.lisaEnabled { + // Try accumulating all FDIDs to fsync and fsync then via one RPC as + // opposed to making an RPC per FDID. Passing a non-nil accFsyncFDIDs to + // dentry.syncCachedFile() and specialFileFD.sync() will cause them to not + // make an RPC, instead accumulate syncable FDIDs in the passed slice. + accFsyncFDIDs := make([]lisafs.FDID, 0, len(ds)+len(sffds)) + + // Sync syncable dentries. + for _, d := range ds { + if err := d.syncCachedFile(ctx, true /* forFilesystemSync */, &accFsyncFDIDs); err != nil { + ctx.Infof("gofer.filesystem.Sync: dentry.syncCachedFile failed: %v", err) + if retErr == nil { + retErr = err + } + } + } + + // Sync special files, which may be writable but do not use dentry shared + // handles (so they won't be synced by the above). + for _, sffd := range sffds { + if err := sffd.sync(ctx, true /* forFilesystemSync */, &accFsyncFDIDs); err != nil { + ctx.Infof("gofer.filesystem.Sync: specialFileFD.sync failed: %v", err) + if retErr == nil { + retErr = err + } + } + } + + if err := fs.clientLisa.SyncFDs(ctx, accFsyncFDIDs); err != nil { + ctx.Infof("gofer.filesystem.Sync: fs.fsyncMultipleFDLisa failed: %v", err) + if retErr == nil { + retErr = err + } + } + + return retErr + } + // Sync syncable dentries. for _, d := range ds { - if err := d.syncCachedFile(ctx, true /* forFilesystemSync */); err != nil { + if err := d.syncCachedFile(ctx, true /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */); err != nil { ctx.Infof("gofer.filesystem.Sync: dentry.syncCachedFile failed: %v", err) if retErr == nil { retErr = err @@ -66,7 +106,7 @@ func (fs *filesystem) Sync(ctx context.Context) error { // Sync special files, which may be writable but do not use dentry shared // handles (so they won't be synced by the above). for _, sffd := range sffds { - if err := sffd.sync(ctx, true /* forFilesystemSync */); err != nil { + if err := sffd.sync(ctx, true /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */); err != nil { ctx.Infof("gofer.filesystem.Sync: specialFileFD.sync failed: %v", err) if retErr == nil { retErr = err @@ -130,7 +170,7 @@ func putDentrySlice(ds *[]*dentry) { // but dentry slices are allocated lazily, and it's much easier to say "defer // fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() { // fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this. -// +checklocksrelease:fs.renameMu +// +checklocksreleaseread:fs.renameMu func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, dsp **[]*dentry) { fs.renameMu.RUnlock() if *dsp == nil { @@ -197,7 +237,13 @@ afterSymlink: rp.Advance() return d.parent, followedSymlink, nil } - child, err := fs.getChildLocked(ctx, d, name, ds) + var child *dentry + var err error + if fs.opts.lisaEnabled { + child, err = fs.getChildAndWalkPathLocked(ctx, d, rp, ds) + } else { + child, err = fs.getChildLocked(ctx, d, name, ds) + } if err != nil { return nil, false, err } @@ -219,6 +265,99 @@ afterSymlink: return child, followedSymlink, nil } +// Preconditions: +// * fs.opts.lisaEnabled. +// * fs.renameMu must be locked. +// * parent.dirMu must be locked. +// * parent.isDir(). +// * parent and the dentry at name have been revalidated. +func (fs *filesystem) getChildAndWalkPathLocked(ctx context.Context, parent *dentry, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) { + // Note that pit is a copy of the iterator that does not affect rp. + pit := rp.Pit() + first := pit.String() + if len(first) > maxFilenameLen { + return nil, linuxerr.ENAMETOOLONG + } + if child, ok := parent.children[first]; ok || parent.isSynthetic() { + if child == nil { + return nil, linuxerr.ENOENT + } + return child, nil + } + + // Walk as much of the path as possible in 1 RPC. + names := []string{first} + for pit = pit.Next(); pit.Ok(); pit = pit.Next() { + name := pit.String() + if name == "." { + continue + } + if name == ".." { + break + } + names = append(names, name) + } + status, inodes, err := parent.controlFDLisa.WalkMultiple(ctx, names) + if err != nil { + return nil, err + } + if len(inodes) == 0 { + parent.cacheNegativeLookupLocked(first) + return nil, linuxerr.ENOENT + } + + // Add the walked inodes into the dentry tree. + curParent := parent + curParentDirMuLock := func() { + if curParent != parent { + curParent.dirMu.Lock() + } + } + curParentDirMuUnlock := func() { + if curParent != parent { + curParent.dirMu.Unlock() // +checklocksforce: locked via curParentDirMuLock(). + } + } + var ret *dentry + var dentryCreationErr error + for i := range inodes { + if dentryCreationErr != nil { + fs.clientLisa.CloseFDBatched(ctx, inodes[i].ControlFD) + continue + } + + child, err := fs.newDentryLisa(ctx, &inodes[i]) + if err != nil { + fs.clientLisa.CloseFDBatched(ctx, inodes[i].ControlFD) + dentryCreationErr = err + continue + } + curParentDirMuLock() + curParent.cacheNewChildLocked(child, names[i]) + curParentDirMuUnlock() + // For now, child has 0 references, so our caller should call + // child.checkCachingLocked(). curParent gained a ref so we should also + // call curParent.checkCachingLocked() so it can be removed from the cache + // if needed. We only do that for the first iteration because all + // subsequent parents would have already been added to ds. + if i == 0 { + *ds = appendDentry(*ds, curParent) + } + *ds = appendDentry(*ds, child) + curParent = child + if i == 0 { + ret = child + } + } + + if status == lisafs.WalkComponentDoesNotExist && curParent.isDir() { + curParentDirMuLock() + curParent.cacheNegativeLookupLocked(names[len(inodes)]) + curParentDirMuUnlock() + } + return ret, dentryCreationErr +} + // getChildLocked returns a dentry representing the child of parent with the // given name. Returns ENOENT if the child doesn't exist. // @@ -227,7 +366,7 @@ afterSymlink: // * parent.dirMu must be locked. // * parent.isDir(). // * name is not "." or "..". -// * dentry at name has been revalidated +// * parent and the dentry at name have been revalidated. func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { if len(name) > maxFilenameLen { return nil, linuxerr.ENAMETOOLONG @@ -239,20 +378,35 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s return child, nil } - qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name) - if err != nil { - if linuxerr.Equals(linuxerr.ENOENT, err) { - parent.cacheNegativeLookupLocked(name) + var child *dentry + if fs.opts.lisaEnabled { + childInode, err := parent.controlFDLisa.Walk(ctx, name) + if err != nil { + if linuxerr.Equals(linuxerr.ENOENT, err) { + parent.cacheNegativeLookupLocked(name) + } + return nil, err + } + // Create a new dentry representing the file. + child, err = fs.newDentryLisa(ctx, childInode) + if err != nil { + fs.clientLisa.CloseFDBatched(ctx, childInode.ControlFD) + return nil, err + } + } else { + qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name) + if err != nil { + if linuxerr.Equals(linuxerr.ENOENT, err) { + parent.cacheNegativeLookupLocked(name) + } + return nil, err + } + // Create a new dentry representing the file. + child, err = fs.newDentry(ctx, file, qid, attrMask, &attr) + if err != nil { + file.close(ctx) + return nil, err } - return nil, err - } - - // Create a new dentry representing the file. - child, err := fs.newDentry(ctx, file, qid, attrMask, &attr) - if err != nil { - file.close(ctx) - delete(parent.children, name) - return nil, err } parent.cacheNewChildLocked(child, name) appendNewChildDentry(ds, parent, child) @@ -328,7 +482,7 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, // Preconditions: // * !rp.Done(). // * For the final path component in rp, !rp.ShouldFollowSymlink(). -func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string, ds **[]*dentry) error, createInSyntheticDir func(parent *dentry, name string) error) error { +func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error), createInSyntheticDir func(parent *dentry, name string) error, updateChild func(child *dentry)) error { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) @@ -415,9 +569,26 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir // No cached dentry exists; however, in InteropModeShared there might still be // an existing file at name. Just attempt the file creation RPC anyways. If a // file does exist, the RPC will fail with EEXIST like we would have. - if err := createInRemoteDir(parent, name, &ds); err != nil { + lisaInode, err := createInRemoteDir(parent, name, &ds) + if err != nil { return err } + // lisafs may aggresively cache newly created inodes. This has helped reduce + // Walk RPCs in practice. + if lisaInode != nil { + child, err := fs.newDentryLisa(ctx, lisaInode) + if err != nil { + fs.clientLisa.CloseFDBatched(ctx, lisaInode.ControlFD) + return err + } + parent.cacheNewChildLocked(child, name) + appendNewChildDentry(&ds, parent, child) + + // lisafs may update dentry properties upon successful creation. + if updateChild != nil { + updateChild(child) + } + } if fs.opts.interop != InteropModeShared { if child, ok := parent.children[name]; ok && child == nil { // Delete the now-stale negative dentry. @@ -565,7 +736,11 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b return linuxerr.ENOENT } } else if child == nil || !child.isSynthetic() { - err = parent.file.unlinkAt(ctx, name, flags) + if fs.opts.lisaEnabled { + err = parent.controlFDLisa.UnlinkAt(ctx, name, flags) + } else { + err = parent.file.unlinkAt(ctx, name, flags) + } if err != nil { if child != nil { vfsObj.AbortDeleteDentry(&child.vfsd) // +checklocksforce: see above. @@ -658,40 +833,43 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa // LinkAt implements vfs.FilesystemImpl.LinkAt. func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, _ **[]*dentry) error { + err := fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, ds **[]*dentry) (*lisafs.Inode, error) { if rp.Mount() != vd.Mount() { - return linuxerr.EXDEV + return nil, linuxerr.EXDEV } d := vd.Dentry().Impl().(*dentry) if d.isDir() { - return linuxerr.EPERM + return nil, linuxerr.EPERM } gid := auth.KGID(atomic.LoadUint32(&d.gid)) uid := auth.KUID(atomic.LoadUint32(&d.uid)) mode := linux.FileMode(atomic.LoadUint32(&d.mode)) if err := vfs.MayLink(rp.Credentials(), mode, uid, gid); err != nil { - return err + return nil, err } if d.nlink == 0 { - return linuxerr.ENOENT + return nil, linuxerr.ENOENT } if d.nlink == math.MaxUint32 { - return linuxerr.EMLINK + return nil, linuxerr.EMLINK } - if err := parent.file.link(ctx, d.file, childName); err != nil { - return err + if fs.opts.lisaEnabled { + return parent.controlFDLisa.LinkAt(ctx, d.controlFDLisa.ID(), childName) } + return nil, parent.file.link(ctx, d.file, childName) + }, nil, nil) + if err == nil { // Success! - atomic.AddUint32(&d.nlink, 1) - return nil - }, nil) + vd.Dentry().Impl().(*dentry).incLinks() + } + return err } // MkdirAt implements vfs.FilesystemImpl.MkdirAt. func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { creds := rp.Credentials() - return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, ds **[]*dentry) error { + return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error) { // If the parent is a setgid directory, use the parent's GID // rather than the caller's and enable setgid. kgid := creds.EffectiveKGID @@ -700,23 +878,37 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v kgid = auth.KGID(atomic.LoadUint32(&parent.gid)) mode |= linux.S_ISGID } - if _, err := parent.file.mkdir(ctx, name, p9.FileMode(mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)); err != nil { - if !opts.ForSyntheticMountpoint || linuxerr.Equals(linuxerr.EEXIST, err) { - return err + var ( + childDirInode *lisafs.Inode + err error + ) + if fs.opts.lisaEnabled { + childDirInode, err = parent.controlFDLisa.MkdirAt(ctx, name, mode, lisafs.UID(creds.EffectiveKUID), lisafs.GID(kgid)) + } else { + _, err = parent.file.mkdir(ctx, name, p9.FileMode(mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)) + } + if err == nil { + if fs.opts.interop != InteropModeShared { + parent.incLinks() } - ctx.Infof("Failed to create remote directory %q: %v; falling back to synthetic directory", name, err) - parent.createSyntheticChildLocked(&createSyntheticOpts{ - name: name, - mode: linux.S_IFDIR | opts.Mode, - kuid: creds.EffectiveKUID, - kgid: creds.EffectiveKGID, - }) - *ds = appendDentry(*ds, parent) + return childDirInode, nil + } + + if !opts.ForSyntheticMountpoint || linuxerr.Equals(linuxerr.EEXIST, err) { + return nil, err } + ctx.Infof("Failed to create remote directory %q: %v; falling back to synthetic directory", name, err) + parent.createSyntheticChildLocked(&createSyntheticOpts{ + name: name, + mode: linux.S_IFDIR | opts.Mode, + kuid: creds.EffectiveKUID, + kgid: creds.EffectiveKGID, + }) + *ds = appendDentry(*ds, parent) if fs.opts.interop != InteropModeShared { parent.incLinks() } - return nil + return nil, nil }, func(parent *dentry, name string) error { if !opts.ForSyntheticMountpoint { // Can't create non-synthetic files in synthetic directories. @@ -730,16 +922,26 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v }) parent.incLinks() return nil - }) + }, nil) } // MknodAt implements vfs.FilesystemImpl.MknodAt. func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) error { + return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error) { creds := rp.Credentials() - _, err := parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) - if !linuxerr.Equals(linuxerr.EPERM, err) { - return err + var ( + childInode *lisafs.Inode + err error + ) + if fs.opts.lisaEnabled { + childInode, err = parent.controlFDLisa.MknodAt(ctx, name, opts.Mode, lisafs.UID(creds.EffectiveKUID), lisafs.GID(creds.EffectiveKGID), opts.DevMinor, opts.DevMajor) + } else { + _, err = parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) + } + if err == nil { + return childInode, nil + } else if !linuxerr.Equals(linuxerr.EPERM, err) { + return nil, err } // EPERM means that gofer does not allow creating a socket or pipe. Fallback @@ -750,10 +952,10 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v switch { case err == nil: // Step succeeded, another file exists. - return linuxerr.EEXIST + return nil, linuxerr.EEXIST case !linuxerr.Equals(linuxerr.ENOENT, err): // Unexpected error. - return err + return nil, err } switch opts.Mode.FileType() { @@ -766,7 +968,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v endpoint: opts.Endpoint, }) *ds = appendDentry(*ds, parent) - return nil + return nil, nil case linux.S_IFIFO: parent.createSyntheticChildLocked(&createSyntheticOpts{ name: name, @@ -776,11 +978,11 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize), }) *ds = appendDentry(*ds, parent) - return nil + return nil, nil } // Retain error from gofer if synthetic file cannot be created internally. - return linuxerr.EPERM - }, nil) + return nil, linuxerr.EPERM + }, nil, nil) } // OpenAt implements vfs.FilesystemImpl.OpenAt. @@ -986,6 +1188,23 @@ func (d *dentry) openSocketByConnecting(ctx context.Context, opts *vfs.OpenOptio if opts.Flags&linux.O_DIRECT != 0 { return nil, linuxerr.EINVAL } + if d.fs.opts.lisaEnabled { + // Note that special value of linux.SockType = 0 is interpreted by lisafs + // as "do not care about the socket type". Analogous to p9.AnonymousSocket. + sockFD, err := d.controlFDLisa.Connect(ctx, 0 /* sockType */) + if err != nil { + return nil, err + } + fd, err := host.NewFD(ctx, kernel.KernelFromContext(ctx).HostMount(), sockFD, &host.NewFDOptions{ + HaveFlags: true, + Flags: opts.Flags, + }) + if err != nil { + unix.Close(sockFD) + return nil, err + } + return fd, nil + } fdObj, err := d.file.connect(ctx, p9.AnonymousSocket) if err != nil { return nil, err @@ -998,6 +1217,7 @@ func (d *dentry) openSocketByConnecting(ctx context.Context, opts *vfs.OpenOptio fdObj.Close() return nil, err } + // Ownership has been transferred to fd. fdObj.Release() return fd, nil } @@ -1017,7 +1237,13 @@ func (d *dentry) openSpecialFile(ctx context.Context, mnt *vfs.Mount, opts *vfs. // since closed its end. isBlockingOpenOfNamedPipe := d.fileType() == linux.S_IFIFO && opts.Flags&linux.O_NONBLOCK == 0 retry: - h, err := openHandle(ctx, d.file, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0) + var h handle + var err error + if d.fs.opts.lisaEnabled { + h, err = openHandleLisa(ctx, d.controlFDLisa, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0) + } else { + h, err = openHandle(ctx, d.file, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0) + } if err != nil { if isBlockingOpenOfNamedPipe && ats == vfs.MayWrite && linuxerr.Equals(linuxerr.ENXIO, err) { // An attempt to open a named pipe with O_WRONLY|O_NONBLOCK fails @@ -1061,18 +1287,8 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } defer mnt.EndWrite() - // 9P2000.L's lcreate takes a fid representing the parent directory, and - // converts it into an open fid representing the created file, so we need - // to duplicate the directory fid first. - _, dirfile, err := d.file.walk(ctx, nil) - if err != nil { - return nil, err - } creds := rp.Credentials() name := rp.Component() - // We only want the access mode for creating the file. - createFlags := p9.OpenFlags(opts.Flags) & p9.OpenFlagsModeMask - // If the parent is a setgid directory, use the parent's GID rather // than the caller's. kgid := creds.EffectiveKGID @@ -1080,51 +1296,87 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving kgid = auth.KGID(atomic.LoadUint32(&d.gid)) } - fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, p9.FileMode(opts.Mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)) - if err != nil { - dirfile.close(ctx) - return nil, err - } - // Then we need to walk to the file we just created to get a non-open fid - // representing it, and to get its metadata. This must use d.file since, as - // explained above, dirfile was invalidated by dirfile.Create(). - _, nonOpenFile, attrMask, attr, err := d.file.walkGetAttrOne(ctx, name) - if err != nil { - openFile.close(ctx) - if fdobj != nil { - fdobj.Close() + var child *dentry + var openP9File p9file + openLisaFD := lisafs.InvalidFDID + openHostFD := int32(-1) + if d.fs.opts.lisaEnabled { + ino, openFD, hostFD, err := d.controlFDLisa.OpenCreateAt(ctx, name, opts.Flags&linux.O_ACCMODE, opts.Mode, lisafs.UID(creds.EffectiveKUID), lisafs.GID(kgid)) + if err != nil { + return nil, err + } + openHostFD = int32(hostFD) + openLisaFD = openFD + + child, err = d.fs.newDentryLisa(ctx, &ino) + if err != nil { + d.fs.clientLisa.CloseFDBatched(ctx, ino.ControlFD) + d.fs.clientLisa.CloseFDBatched(ctx, openFD) + if hostFD >= 0 { + unix.Close(hostFD) + } + return nil, err + } + } else { + // 9P2000.L's lcreate takes a fid representing the parent directory, and + // converts it into an open fid representing the created file, so we need + // to duplicate the directory fid first. + _, dirfile, err := d.file.walk(ctx, nil) + if err != nil { + return nil, err + } + // We only want the access mode for creating the file. + createFlags := p9.OpenFlags(opts.Flags) & p9.OpenFlagsModeMask + + fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, p9.FileMode(opts.Mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)) + if err != nil { + dirfile.close(ctx) + return nil, err + } + // Then we need to walk to the file we just created to get a non-open fid + // representing it, and to get its metadata. This must use d.file since, as + // explained above, dirfile was invalidated by dirfile.Create(). + _, nonOpenFile, attrMask, attr, err := d.file.walkGetAttrOne(ctx, name) + if err != nil { + openFile.close(ctx) + if fdobj != nil { + fdobj.Close() + } + return nil, err + } + + // Construct the new dentry. + child, err = d.fs.newDentry(ctx, nonOpenFile, createQID, attrMask, &attr) + if err != nil { + nonOpenFile.close(ctx) + openFile.close(ctx) + if fdobj != nil { + fdobj.Close() + } + return nil, err } - return nil, err - } - // Construct the new dentry. - child, err := d.fs.newDentry(ctx, nonOpenFile, createQID, attrMask, &attr) - if err != nil { - nonOpenFile.close(ctx) - openFile.close(ctx) if fdobj != nil { - fdobj.Close() + openHostFD = int32(fdobj.Release()) } - return nil, err + openP9File = openFile } // Incorporate the fid that was opened by lcreate. useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD if useRegularFileFD { - openFD := int32(-1) - if fdobj != nil { - openFD = int32(fdobj.Release()) - } child.handleMu.Lock() if vfs.MayReadFileWithOpenFlags(opts.Flags) { - child.readFile = openFile - if fdobj != nil { - child.readFD = openFD - child.mmapFD = openFD + child.readFile = openP9File + child.readFDLisa = d.fs.clientLisa.NewFD(openLisaFD) + if openHostFD != -1 { + child.readFD = openHostFD + child.mmapFD = openHostFD } } if vfs.MayWriteFileWithOpenFlags(opts.Flags) { - child.writeFile = openFile - child.writeFD = openFD + child.writeFile = openP9File + child.writeFDLisa = d.fs.clientLisa.NewFD(openLisaFD) + child.writeFD = openHostFD } child.handleMu.Unlock() } @@ -1146,11 +1398,9 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving childVFSFD = &fd.vfsfd } else { h := handle{ - file: openFile, - fd: -1, - } - if fdobj != nil { - h.fd = int32(fdobj.Release()) + file: openP9File, + fdLisa: d.fs.clientLisa.NewFD(openLisaFD), + fd: openHostFD, } fd, err := newSpecialFileFD(h, mnt, child, opts.Flags) if err != nil { @@ -1304,7 +1554,12 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // Update the remote filesystem. if !renamed.isSynthetic() { - if err := renamed.file.rename(ctx, newParent.file, newName); err != nil { + if fs.opts.lisaEnabled { + err = renamed.controlFDLisa.RenameTo(ctx, newParent.controlFDLisa.ID(), newName) + } else { + err = renamed.file.rename(ctx, newParent.file, newName) + } + if err != nil { vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD) return err } @@ -1315,7 +1570,12 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if replaced.isDir() { flags = linux.AT_REMOVEDIR } - if err := newParent.file.unlinkAt(ctx, newName, flags); err != nil { + if fs.opts.lisaEnabled { + err = newParent.controlFDLisa.UnlinkAt(ctx, newName, flags) + } else { + err = newParent.file.unlinkAt(ctx, newName, flags) + } + if err != nil { vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD) return err } @@ -1431,6 +1691,28 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu for d.isSynthetic() { d = d.parent } + if fs.opts.lisaEnabled { + var statFS lisafs.StatFS + if err := d.controlFDLisa.StatFSTo(ctx, &statFS); err != nil { + return linux.Statfs{}, err + } + if statFS.NameLength > maxFilenameLen { + statFS.NameLength = maxFilenameLen + } + return linux.Statfs{ + // This is primarily for distinguishing a gofer file system in + // tests. Testing is important, so instead of defining + // something completely random, use a standard value. + Type: linux.V9FS_MAGIC, + BlockSize: statFS.BlockSize, + Blocks: statFS.Blocks, + BlocksFree: statFS.BlocksFree, + BlocksAvailable: statFS.BlocksAvailable, + Files: statFS.Files, + FilesFree: statFS.FilesFree, + NameLength: statFS.NameLength, + }, nil + } fsstat, err := d.file.statFS(ctx) if err != nil { return linux.Statfs{}, err @@ -1456,11 +1738,21 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu // SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, _ **[]*dentry) error { + return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error) { creds := rp.Credentials() + if fs.opts.lisaEnabled { + return parent.controlFDLisa.SymlinkAt(ctx, name, target, lisafs.UID(creds.EffectiveKUID), lisafs.GID(creds.EffectiveKGID)) + } _, err := parent.file.symlink(ctx, target, name, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) - return err - }, nil) + return nil, err + }, nil, func(child *dentry) { + if fs.opts.interop != InteropModeShared { + // lisafs caches the symlink target on creation. In practice, this + // helps avoid a lot of ReadLink RPCs. + child.haveTarget = true + child.target = target + } + }) } // UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. @@ -1505,7 +1797,7 @@ func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, si if err != nil { return nil, err } - return d.listXattr(ctx, rp.Credentials(), size) + return d.listXattr(ctx, size) } // GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. @@ -1612,6 +1904,9 @@ func (fs *filesystem) MountOptions() string { if fs.opts.overlayfsStaleRead { optsKV = append(optsKV, mopt{moptOverlayfsStaleRead, nil}) } + if fs.opts.lisaEnabled { + optsKV = append(optsKV, mopt{moptLisafs, nil}) + } opts := make([]string, 0, len(optsKV)) for _, opt := range optsKV { diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 43440ec19..b98825e26 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -48,6 +48,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/lisafs" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" refs_vfs1 "gvisor.dev/gvisor/pkg/refs" @@ -83,6 +84,7 @@ const ( moptForcePageCache = "force_page_cache" moptLimitHostFDTranslation = "limit_host_fd_translation" moptOverlayfsStaleRead = "overlayfs_stale_read" + moptLisafs = "lisafs" ) // Valid values for the "cache" mount option. @@ -118,6 +120,10 @@ type filesystem struct { // client is the client used by this filesystem. client is immutable. client *p9.Client `state:"nosave"` + // clientLisa is the client used for communicating with the server when + // lisafs is enabled. lisafsCient is immutable. + clientLisa *lisafs.Client `state:"nosave"` + // clock is a realtime clock used to set timestamps in file operations. clock ktime.Clock @@ -161,6 +167,12 @@ type filesystem struct { inoMu sync.Mutex `state:"nosave"` inoByQIDPath map[uint64]uint64 `state:"nosave"` + // inoByKey is the same as inoByQIDPath but only used by lisafs. It helps + // identify inodes based on the device ID and host inode number provided + // by the gofer process. It is not preserved across checkpoint/restore for + // the same reason as above. inoByKey is protected by inoMu. + inoByKey map[inoKey]uint64 `state:"nosave"` + // lastIno is the last inode number assigned to a file. lastIno is accessed // using atomic memory operations. lastIno uint64 @@ -214,6 +226,10 @@ type filesystemOptions struct { // way that application FDs representing "special files" such as sockets // do. Note that this disables client caching and mmap for regular files. regularFilesUseSpecialFileFD bool + + // lisaEnabled indicates whether the client will use lisafs protocol to + // communicate with the server instead of 9P. + lisaEnabled bool } // InteropMode controls the client's interaction with other remote filesystem @@ -427,6 +443,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt delete(mopts, moptOverlayfsStaleRead) fsopts.overlayfsStaleRead = true } + if lisafs, ok := mopts[moptLisafs]; ok { + delete(mopts, moptLisafs) + fsopts.lisaEnabled, err = strconv.ParseBool(lisafs) + if err != nil { + ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid lisafs option: %s", lisafs) + return nil, nil, linuxerr.EINVAL + } + } // fsopts.regularFilesUseSpecialFileFD can only be enabled by specifying // "cache=none". @@ -458,44 +482,83 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt syncableDentries: make(map[*dentry]struct{}), specialFileFDs: make(map[*specialFileFD]struct{}), inoByQIDPath: make(map[uint64]uint64), + inoByKey: make(map[inoKey]uint64), } fs.vfsfs.Init(vfsObj, &fstype, fs) + if err := fs.initClientAndRoot(ctx); err != nil { + fs.vfsfs.DecRef(ctx) + return nil, nil, err + } + + return &fs.vfsfs, &fs.root.vfsd, nil +} + +func (fs *filesystem) initClientAndRoot(ctx context.Context) error { + var err error + if fs.opts.lisaEnabled { + var rootInode *lisafs.Inode + rootInode, err = fs.initClientLisa(ctx) + if err != nil { + return err + } + fs.root, err = fs.newDentryLisa(ctx, rootInode) + if err != nil { + fs.clientLisa.CloseFDBatched(ctx, rootInode.ControlFD) + } + } else { + fs.root, err = fs.initClient(ctx) + } + + // Set the root's reference count to 2. One reference is returned to the + // caller, and the other is held by fs to prevent the root from being "cached" + // and subsequently evicted. + if err == nil { + fs.root.refs = 2 + } + return err +} + +func (fs *filesystem) initClientLisa(ctx context.Context) (*lisafs.Inode, error) { + sock, err := unet.NewSocket(fs.opts.fd) + if err != nil { + return nil, err + } + + var rootInode *lisafs.Inode + ctx.UninterruptibleSleepStart(false) + fs.clientLisa, rootInode, err = lisafs.NewClient(sock, fs.opts.aname) + ctx.UninterruptibleSleepFinish(false) + return rootInode, err +} + +func (fs *filesystem) initClient(ctx context.Context) (*dentry, error) { // Connect to the server. if err := fs.dial(ctx); err != nil { - return nil, nil, err + return nil, err } // Perform attach to obtain the filesystem root. ctx.UninterruptibleSleepStart(false) - attached, err := fs.client.Attach(fsopts.aname) + attached, err := fs.client.Attach(fs.opts.aname) ctx.UninterruptibleSleepFinish(false) if err != nil { - fs.vfsfs.DecRef(ctx) - return nil, nil, err + return nil, err } attachFile := p9file{attached} qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask()) if err != nil { attachFile.close(ctx) - fs.vfsfs.DecRef(ctx) - return nil, nil, err + return nil, err } // Construct the root dentry. root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr) if err != nil { attachFile.close(ctx) - fs.vfsfs.DecRef(ctx) - return nil, nil, err + return nil, err } - // Set the root's reference count to 2. One reference is returned to the - // caller, and the other is held by fs to prevent the root from being "cached" - // and subsequently evicted. - root.refs = 2 - fs.root = root - - return &fs.vfsfs, &root.vfsd, nil + return root, nil } func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int, error) { @@ -613,7 +676,11 @@ func (fs *filesystem) Release(ctx context.Context) { if !fs.iopts.LeakConnection { // Close the connection to the server. This implicitly clunks all fids. - fs.client.Close() + if fs.opts.lisaEnabled { + fs.clientLisa.Close() + } else { + fs.client.Close() + } } fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) @@ -644,6 +711,23 @@ func (d *dentry) releaseSyntheticRecursiveLocked(ctx context.Context) { } } +// inoKey is the key used to identify the inode backed by this dentry. +// +// +stateify savable +type inoKey struct { + ino uint64 + devMinor uint32 + devMajor uint32 +} + +func inoKeyFromStat(stat *linux.Statx) inoKey { + return inoKey{ + ino: stat.Ino, + devMinor: stat.DevMinor, + devMajor: stat.DevMajor, + } +} + // dentry implements vfs.DentryImpl. // // +stateify savable @@ -674,6 +758,9 @@ type dentry struct { // qidPath is the p9.QID.Path for this file. qidPath is immutable. qidPath uint64 + // inoKey is used to identify this dentry's inode. + inoKey inoKey + // file is the unopened p9.File that backs this dentry. file is immutable. // // If file.isNil(), this dentry represents a synthetic file, i.e. a file @@ -681,6 +768,14 @@ type dentry struct { // only files that can be synthetic are sockets, pipes, and directories. file p9file `state:"nosave"` + // controlFDLisa is used by lisafs to perform path based operations on this + // dentry. + // + // if !controlFDLisa.Ok(), this dentry represents a synthetic file, i.e. a + // file that does not exist on the remote filesystem. As of this writing, the + // only files that can be synthetic are sockets, pipes, and directories. + controlFDLisa lisafs.ClientFD `state:"nosave"` + // If deleted is non-zero, the file represented by this dentry has been // deleted. deleted is accessed using atomic memory operations. deleted uint32 @@ -791,12 +886,14 @@ type dentry struct { // always either -1 or equal to readFD; if !writeFile.isNil() (the file has // been opened for writing), it is additionally either -1 or equal to // writeFD. - handleMu sync.RWMutex `state:"nosave"` - readFile p9file `state:"nosave"` - writeFile p9file `state:"nosave"` - readFD int32 `state:"nosave"` - writeFD int32 `state:"nosave"` - mmapFD int32 `state:"nosave"` + handleMu sync.RWMutex `state:"nosave"` + readFile p9file `state:"nosave"` + writeFile p9file `state:"nosave"` + readFDLisa lisafs.ClientFD `state:"nosave"` + writeFDLisa lisafs.ClientFD `state:"nosave"` + readFD int32 `state:"nosave"` + writeFD int32 `state:"nosave"` + mmapFD int32 `state:"nosave"` dataMu sync.RWMutex `state:"nosave"` @@ -920,6 +1017,79 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma return d, nil } +func (fs *filesystem) newDentryLisa(ctx context.Context, ino *lisafs.Inode) (*dentry, error) { + if ino.Stat.Mask&linux.STATX_TYPE == 0 { + ctx.Warningf("can't create gofer.dentry without file type") + return nil, linuxerr.EIO + } + if ino.Stat.Mode&linux.FileTypeMask == linux.ModeRegular && ino.Stat.Mask&linux.STATX_SIZE == 0 { + ctx.Warningf("can't create regular file gofer.dentry without file size") + return nil, linuxerr.EIO + } + + inoKey := inoKeyFromStat(&ino.Stat) + d := &dentry{ + fs: fs, + inoKey: inoKey, + ino: fs.inoFromKey(inoKey), + mode: uint32(ino.Stat.Mode), + uid: uint32(fs.opts.dfltuid), + gid: uint32(fs.opts.dfltgid), + blockSize: hostarch.PageSize, + readFD: -1, + writeFD: -1, + mmapFD: -1, + controlFDLisa: fs.clientLisa.NewFD(ino.ControlFD), + } + + d.pf.dentry = d + if ino.Stat.Mask&linux.STATX_UID != 0 { + d.uid = dentryUIDFromLisaUID(lisafs.UID(ino.Stat.UID)) + } + if ino.Stat.Mask&linux.STATX_GID != 0 { + d.gid = dentryGIDFromLisaGID(lisafs.GID(ino.Stat.GID)) + } + if ino.Stat.Mask&linux.STATX_SIZE != 0 { + d.size = ino.Stat.Size + } + if ino.Stat.Blksize != 0 { + d.blockSize = ino.Stat.Blksize + } + if ino.Stat.Mask&linux.STATX_ATIME != 0 { + d.atime = dentryTimestampFromLisa(ino.Stat.Atime) + } + if ino.Stat.Mask&linux.STATX_MTIME != 0 { + d.mtime = dentryTimestampFromLisa(ino.Stat.Mtime) + } + if ino.Stat.Mask&linux.STATX_CTIME != 0 { + d.ctime = dentryTimestampFromLisa(ino.Stat.Ctime) + } + if ino.Stat.Mask&linux.STATX_BTIME != 0 { + d.btime = dentryTimestampFromLisa(ino.Stat.Btime) + } + if ino.Stat.Mask&linux.STATX_NLINK != 0 { + d.nlink = ino.Stat.Nlink + } + d.vfsd.Init(d) + refsvfs2.Register(d) + fs.syncMu.Lock() + fs.syncableDentries[d] = struct{}{} + fs.syncMu.Unlock() + return d, nil +} + +func (fs *filesystem) inoFromKey(key inoKey) uint64 { + fs.inoMu.Lock() + defer fs.inoMu.Unlock() + + if ino, ok := fs.inoByKey[key]; ok { + return ino + } + ino := fs.nextIno() + fs.inoByKey[key] = ino + return ino +} + func (fs *filesystem) inoFromQIDPath(qidPath uint64) uint64 { fs.inoMu.Lock() defer fs.inoMu.Unlock() @@ -936,7 +1106,7 @@ func (fs *filesystem) nextIno() uint64 { } func (d *dentry) isSynthetic() bool { - return d.file.isNil() + return !d.isControlFileOk() } func (d *dentry) cachedMetadataAuthoritative() bool { @@ -986,6 +1156,50 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { } } +// updateFromLisaStatLocked is called to update d's metadata after an update +// from the remote filesystem. +// Precondition: d.metadataMu must be locked. +// +checklocks:d.metadataMu +func (d *dentry) updateFromLisaStatLocked(stat *linux.Statx) { + if stat.Mask&linux.STATX_TYPE != 0 { + if got, want := stat.Mode&linux.FileTypeMask, d.fileType(); uint32(got) != want { + panic(fmt.Sprintf("gofer.dentry file type changed from %#o to %#o", want, got)) + } + } + if stat.Mask&linux.STATX_MODE != 0 { + atomic.StoreUint32(&d.mode, uint32(stat.Mode)) + } + if stat.Mask&linux.STATX_UID != 0 { + atomic.StoreUint32(&d.uid, dentryUIDFromLisaUID(lisafs.UID(stat.UID))) + } + if stat.Mask&linux.STATX_GID != 0 { + atomic.StoreUint32(&d.uid, dentryGIDFromLisaGID(lisafs.GID(stat.GID))) + } + if stat.Blksize != 0 { + atomic.StoreUint32(&d.blockSize, stat.Blksize) + } + // Don't override newer client-defined timestamps with old server-defined + // ones. + if stat.Mask&linux.STATX_ATIME != 0 && atomic.LoadUint32(&d.atimeDirty) == 0 { + atomic.StoreInt64(&d.atime, dentryTimestampFromLisa(stat.Atime)) + } + if stat.Mask&linux.STATX_MTIME != 0 && atomic.LoadUint32(&d.mtimeDirty) == 0 { + atomic.StoreInt64(&d.mtime, dentryTimestampFromLisa(stat.Mtime)) + } + if stat.Mask&linux.STATX_CTIME != 0 { + atomic.StoreInt64(&d.ctime, dentryTimestampFromLisa(stat.Ctime)) + } + if stat.Mask&linux.STATX_BTIME != 0 { + atomic.StoreInt64(&d.btime, dentryTimestampFromLisa(stat.Btime)) + } + if stat.Mask&linux.STATX_NLINK != 0 { + atomic.StoreUint32(&d.nlink, stat.Nlink) + } + if stat.Mask&linux.STATX_SIZE != 0 { + d.updateSizeLocked(stat.Size) + } +} + // Preconditions: !d.isSynthetic(). // Preconditions: d.metadataMu is locked. // +checklocks:d.metadataMu @@ -995,6 +1209,9 @@ func (d *dentry) refreshSizeLocked(ctx context.Context) error { if d.writeFD < 0 { d.handleMu.RUnlock() // Ask the gofer if we don't have a host FD. + if d.fs.opts.lisaEnabled { + return d.updateFromStatLisaLocked(ctx, nil) + } return d.updateFromGetattrLocked(ctx, p9file{}) } @@ -1014,6 +1231,9 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error { // updating stale attributes in d.updateFromP9AttrsLocked(). d.metadataMu.Lock() defer d.metadataMu.Unlock() + if d.fs.opts.lisaEnabled { + return d.updateFromStatLisaLocked(ctx, nil) + } return d.updateFromGetattrLocked(ctx, p9file{}) } @@ -1021,6 +1241,45 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error { // * !d.isSynthetic(). // * d.metadataMu is locked. // +checklocks:d.metadataMu +func (d *dentry) updateFromStatLisaLocked(ctx context.Context, fdLisa *lisafs.ClientFD) error { + handleMuRLocked := false + if fdLisa == nil { + // Use open FDs in preferenece to the control FD. This may be significantly + // more efficient in some implementations. Prefer a writable FD over a + // readable one since some filesystem implementations may update a writable + // FD's metadata after writes, without making metadata updates immediately + // visible to read-only FDs representing the same file. + d.handleMu.RLock() + switch { + case d.writeFDLisa.Ok(): + fdLisa = &d.writeFDLisa + handleMuRLocked = true + case d.readFDLisa.Ok(): + fdLisa = &d.readFDLisa + handleMuRLocked = true + default: + fdLisa = &d.controlFDLisa + d.handleMu.RUnlock() + } + } + + var stat linux.Statx + err := fdLisa.StatTo(ctx, &stat) + if handleMuRLocked { + // handleMu must be released before updateFromLisaStatLocked(). + d.handleMu.RUnlock() // +checklocksforce: complex case. + } + if err != nil { + return err + } + d.updateFromLisaStatLocked(&stat) + return nil +} + +// Preconditions: +// * !d.isSynthetic(). +// * d.metadataMu is locked. +// +checklocks:d.metadataMu func (d *dentry) updateFromGetattrLocked(ctx context.Context, file p9file) error { handleMuRLocked := false if file.isNil() { @@ -1160,6 +1419,13 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs } } + // failureMask indicates which attributes could not be set on the remote + // filesystem. p9 returns an error if any of the attributes could not be set + // but that leads to inconsistency as the server could have set a few + // attributes successfully but a later failure will cause the successful ones + // to not be updated in the dentry cache. + var failureMask uint32 + var failureErr error if !d.isSynthetic() { if stat.Mask != 0 { if stat.Mask&linux.STATX_SIZE != 0 { @@ -1169,35 +1435,50 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs // the remote file has been truncated). d.dataMu.Lock() } - if err := d.file.setAttr(ctx, p9.SetAttrMask{ - Permissions: stat.Mask&linux.STATX_MODE != 0, - UID: stat.Mask&linux.STATX_UID != 0, - GID: stat.Mask&linux.STATX_GID != 0, - Size: stat.Mask&linux.STATX_SIZE != 0, - ATime: stat.Mask&linux.STATX_ATIME != 0, - MTime: stat.Mask&linux.STATX_MTIME != 0, - ATimeNotSystemTime: stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW, - MTimeNotSystemTime: stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW, - }, p9.SetAttr{ - Permissions: p9.FileMode(stat.Mode), - UID: p9.UID(stat.UID), - GID: p9.GID(stat.GID), - Size: stat.Size, - ATimeSeconds: uint64(stat.Atime.Sec), - ATimeNanoSeconds: uint64(stat.Atime.Nsec), - MTimeSeconds: uint64(stat.Mtime.Sec), - MTimeNanoSeconds: uint64(stat.Mtime.Nsec), - }); err != nil { - if stat.Mask&linux.STATX_SIZE != 0 { - d.dataMu.Unlock() // +checklocksforce: locked conditionally above + if d.fs.opts.lisaEnabled { + var err error + failureMask, failureErr, err = d.controlFDLisa.SetStat(ctx, stat) + if err != nil { + if stat.Mask&linux.STATX_SIZE != 0 { + d.dataMu.Unlock() // +checklocksforce: locked conditionally above + } + return err + } + } else { + if err := d.file.setAttr(ctx, p9.SetAttrMask{ + Permissions: stat.Mask&linux.STATX_MODE != 0, + UID: stat.Mask&linux.STATX_UID != 0, + GID: stat.Mask&linux.STATX_GID != 0, + Size: stat.Mask&linux.STATX_SIZE != 0, + ATime: stat.Mask&linux.STATX_ATIME != 0, + MTime: stat.Mask&linux.STATX_MTIME != 0, + ATimeNotSystemTime: stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW, + MTimeNotSystemTime: stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW, + }, p9.SetAttr{ + Permissions: p9.FileMode(stat.Mode), + UID: p9.UID(stat.UID), + GID: p9.GID(stat.GID), + Size: stat.Size, + ATimeSeconds: uint64(stat.Atime.Sec), + ATimeNanoSeconds: uint64(stat.Atime.Nsec), + MTimeSeconds: uint64(stat.Mtime.Sec), + MTimeNanoSeconds: uint64(stat.Mtime.Nsec), + }); err != nil { + if stat.Mask&linux.STATX_SIZE != 0 { + d.dataMu.Unlock() // +checklocksforce: locked conditionally above + } + return err } - return err } if stat.Mask&linux.STATX_SIZE != 0 { - // d.size should be kept up to date, and privatized - // copy-on-write mappings of truncated pages need to be - // invalidated, even if InteropModeShared is in effect. - d.updateSizeAndUnlockDataMuLocked(stat.Size) // +checklocksforce: locked conditionally above + if failureMask&linux.STATX_SIZE == 0 { + // d.size should be kept up to date, and privatized + // copy-on-write mappings of truncated pages need to be + // invalidated, even if InteropModeShared is in effect. + d.updateSizeAndUnlockDataMuLocked(stat.Size) // +checklocksforce: locked conditionally above + } else { + d.dataMu.Unlock() // +checklocksforce: locked conditionally above + } } } if d.fs.opts.interop == InteropModeShared { @@ -1208,13 +1489,13 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs return nil } } - if stat.Mask&linux.STATX_MODE != 0 { + if stat.Mask&linux.STATX_MODE != 0 && failureMask&linux.STATX_MODE == 0 { atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode)) } - if stat.Mask&linux.STATX_UID != 0 { + if stat.Mask&linux.STATX_UID != 0 && failureMask&linux.STATX_UID == 0 { atomic.StoreUint32(&d.uid, stat.UID) } - if stat.Mask&linux.STATX_GID != 0 { + if stat.Mask&linux.STATX_GID != 0 && failureMask&linux.STATX_GID == 0 { atomic.StoreUint32(&d.gid, stat.GID) } // Note that stat.Atime.Nsec and stat.Mtime.Nsec can't be UTIME_NOW because @@ -1222,15 +1503,19 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs // stat.Mtime to client-local timestamps above, and if // !d.cachedMetadataAuthoritative() then we returned after calling // d.file.setAttr(). For the same reason, now must have been initialized. - if stat.Mask&linux.STATX_ATIME != 0 { + if stat.Mask&linux.STATX_ATIME != 0 && failureMask&linux.STATX_ATIME == 0 { atomic.StoreInt64(&d.atime, stat.Atime.ToNsec()) atomic.StoreUint32(&d.atimeDirty, 0) } - if stat.Mask&linux.STATX_MTIME != 0 { + if stat.Mask&linux.STATX_MTIME != 0 && failureMask&linux.STATX_MTIME == 0 { atomic.StoreInt64(&d.mtime, stat.Mtime.ToNsec()) atomic.StoreUint32(&d.mtimeDirty, 0) } atomic.StoreInt64(&d.ctime, now) + if failureMask != 0 { + // Setting some attribute failed on the remote filesystem. + return failureErr + } return nil } @@ -1310,7 +1595,10 @@ func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats // (b/148380782). Allow all other extended attributes to be passed through // to the remote filesystem. This is inconsistent with Linux's 9p client, // but consistent with other filesystems (e.g. FUSE). - if strings.HasPrefix(name, linux.XATTR_SECURITY_PREFIX) || strings.HasPrefix(name, linux.XATTR_SYSTEM_PREFIX) { + // + // NOTE(b/202533394): Also disallow "trusted" namespace for now. This is + // consistent with the VFS1 gofer client. + if strings.HasPrefix(name, linux.XATTR_SECURITY_PREFIX) || strings.HasPrefix(name, linux.XATTR_SYSTEM_PREFIX) || strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX) { return linuxerr.EOPNOTSUPP } mode := linux.FileMode(atomic.LoadUint32(&d.mode)) @@ -1346,6 +1634,20 @@ func dentryGIDFromP9GID(gid p9.GID) uint32 { return uint32(gid) } +func dentryUIDFromLisaUID(uid lisafs.UID) uint32 { + if !uid.Ok() { + return uint32(auth.OverflowUID) + } + return uint32(uid) +} + +func dentryGIDFromLisaGID(gid lisafs.GID) uint32 { + if !gid.Ok() { + return uint32(auth.OverflowGID) + } + return uint32(gid) +} + // IncRef implements vfs.DentryImpl.IncRef. func (d *dentry) IncRef() { // d.refs may be 0 if d.fs.renameMu is locked, which serializes against @@ -1654,15 +1956,24 @@ func (d *dentry) destroyLocked(ctx context.Context) { d.dirty.RemoveAll() } d.dataMu.Unlock() - // Clunk open fids and close open host FDs. - if !d.readFile.isNil() { - _ = d.readFile.close(ctx) - } - if !d.writeFile.isNil() && d.readFile != d.writeFile { - _ = d.writeFile.close(ctx) + if d.fs.opts.lisaEnabled { + if d.readFDLisa.Ok() && d.readFDLisa.ID() != d.writeFDLisa.ID() { + d.readFDLisa.CloseBatched(ctx) + } + if d.writeFDLisa.Ok() { + d.writeFDLisa.CloseBatched(ctx) + } + } else { + // Clunk open fids and close open host FDs. + if !d.readFile.isNil() { + _ = d.readFile.close(ctx) + } + if !d.writeFile.isNil() && d.readFile != d.writeFile { + _ = d.writeFile.close(ctx) + } + d.readFile = p9file{} + d.writeFile = p9file{} } - d.readFile = p9file{} - d.writeFile = p9file{} if d.readFD >= 0 { _ = unix.Close(int(d.readFD)) } @@ -1674,7 +1985,7 @@ func (d *dentry) destroyLocked(ctx context.Context) { d.mmapFD = -1 d.handleMu.Unlock() - if !d.file.isNil() { + if d.isControlFileOk() { // Note that it's possible that d.atimeDirty or d.mtimeDirty are true, // i.e. client and server timestamps may differ (because e.g. a client // write was serviced by the page cache, and only written back to the @@ -1683,10 +1994,16 @@ func (d *dentry) destroyLocked(ctx context.Context) { // instantiated for the same file would remain coherent. Unfortunately, // this turns out to be too expensive in many cases, so for now we // don't do this. - if err := d.file.close(ctx); err != nil { - log.Warningf("gofer.dentry.destroyLocked: failed to close file: %v", err) + + // Close the control FD. + if d.fs.opts.lisaEnabled { + d.controlFDLisa.CloseBatched(ctx) + } else { + if err := d.file.close(ctx); err != nil { + log.Warningf("gofer.dentry.destroyLocked: failed to close file: %v", err) + } + d.file = p9file{} } - d.file = p9file{} // Remove d from the set of syncable dentries. d.fs.syncMu.Lock() @@ -1712,10 +2029,29 @@ func (d *dentry) setDeleted() { atomic.StoreUint32(&d.deleted, 1) } -func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) { - if d.file.isNil() { +func (d *dentry) isControlFileOk() bool { + if d.fs.opts.lisaEnabled { + return d.controlFDLisa.Ok() + } + return !d.file.isNil() +} + +func (d *dentry) isReadFileOk() bool { + if d.fs.opts.lisaEnabled { + return d.readFDLisa.Ok() + } + return !d.readFile.isNil() +} + +func (d *dentry) listXattr(ctx context.Context, size uint64) ([]string, error) { + if !d.isControlFileOk() { return nil, nil } + + if d.fs.opts.lisaEnabled { + return d.controlFDLisa.ListXattr(ctx, size) + } + xattrMap, err := d.file.listXattr(ctx, size) if err != nil { return nil, err @@ -1728,32 +2064,41 @@ func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size ui } func (d *dentry) getXattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) { - if d.file.isNil() { + if !d.isControlFileOk() { return "", linuxerr.ENODATA } if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil { return "", err } + if d.fs.opts.lisaEnabled { + return d.controlFDLisa.GetXattr(ctx, opts.Name, opts.Size) + } return d.file.getXattr(ctx, opts.Name, opts.Size) } func (d *dentry) setXattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetXattrOptions) error { - if d.file.isNil() { + if !d.isControlFileOk() { return linuxerr.EPERM } if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil { return err } + if d.fs.opts.lisaEnabled { + return d.controlFDLisa.SetXattr(ctx, opts.Name, opts.Value, opts.Flags) + } return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags) } func (d *dentry) removeXattr(ctx context.Context, creds *auth.Credentials, name string) error { - if d.file.isNil() { + if !d.isControlFileOk() { return linuxerr.EPERM } if err := d.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil { return err } + if d.fs.opts.lisaEnabled { + return d.controlFDLisa.RemoveXattr(ctx, name) + } return d.file.removeXattr(ctx, name) } @@ -1765,19 +2110,30 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool // O_TRUNC). if !trunc { d.handleMu.RLock() - if (!read || !d.readFile.isNil()) && (!write || !d.writeFile.isNil()) { + var canReuseCurHandle bool + if d.fs.opts.lisaEnabled { + canReuseCurHandle = (!read || d.readFDLisa.Ok()) && (!write || d.writeFDLisa.Ok()) + } else { + canReuseCurHandle = (!read || !d.readFile.isNil()) && (!write || !d.writeFile.isNil()) + } + d.handleMu.RUnlock() + if canReuseCurHandle { // Current handles are sufficient. - d.handleMu.RUnlock() return nil } - d.handleMu.RUnlock() } var fdsToCloseArr [2]int32 fdsToClose := fdsToCloseArr[:0] invalidateTranslations := false d.handleMu.Lock() - if (read && d.readFile.isNil()) || (write && d.writeFile.isNil()) || trunc { + var needNewHandle bool + if d.fs.opts.lisaEnabled { + needNewHandle = (read && !d.readFDLisa.Ok()) || (write && !d.writeFDLisa.Ok()) || trunc + } else { + needNewHandle = (read && d.readFile.isNil()) || (write && d.writeFile.isNil()) || trunc + } + if needNewHandle { // Get a new handle. If this file has been opened for both reading and // writing, try to get a single handle that is usable for both: // @@ -1786,9 +2142,21 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool // // - NOTE(b/141991141): Some filesystems may not ensure coherence // between multiple handles for the same file. - openReadable := !d.readFile.isNil() || read - openWritable := !d.writeFile.isNil() || write - h, err := openHandle(ctx, d.file, openReadable, openWritable, trunc) + var ( + openReadable bool + openWritable bool + h handle + err error + ) + if d.fs.opts.lisaEnabled { + openReadable = d.readFDLisa.Ok() || read + openWritable = d.writeFDLisa.Ok() || write + h, err = openHandleLisa(ctx, d.controlFDLisa, openReadable, openWritable, trunc) + } else { + openReadable = !d.readFile.isNil() || read + openWritable = !d.writeFile.isNil() || write + h, err = openHandle(ctx, d.file, openReadable, openWritable, trunc) + } if linuxerr.Equals(linuxerr.EACCES, err) && (openReadable != read || openWritable != write) { // It may not be possible to use a single handle for both // reading and writing, since permissions on the file may have @@ -1798,7 +2166,11 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool ctx.Debugf("gofer.dentry.ensureSharedHandle: bifurcating read/write handles for dentry %p", d) openReadable = read openWritable = write - h, err = openHandle(ctx, d.file, openReadable, openWritable, trunc) + if d.fs.opts.lisaEnabled { + h, err = openHandleLisa(ctx, d.controlFDLisa, openReadable, openWritable, trunc) + } else { + h, err = openHandle(ctx, d.file, openReadable, openWritable, trunc) + } } if err != nil { d.handleMu.Unlock() @@ -1860,9 +2232,16 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool // previously opened for reading (without an FD), then existing // translations of the file may use the internal page cache; // invalidate those mappings. - if d.writeFile.isNil() { - invalidateTranslations = !d.readFile.isNil() - atomic.StoreInt32(&d.mmapFD, h.fd) + if d.fs.opts.lisaEnabled { + if !d.writeFDLisa.Ok() { + invalidateTranslations = d.readFDLisa.Ok() + atomic.StoreInt32(&d.mmapFD, h.fd) + } + } else { + if d.writeFile.isNil() { + invalidateTranslations = !d.readFile.isNil() + atomic.StoreInt32(&d.mmapFD, h.fd) + } } } else if openWritable && d.writeFD < 0 { atomic.StoreInt32(&d.writeFD, h.fd) @@ -1889,24 +2268,45 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool atomic.StoreInt32(&d.mmapFD, -1) } - // Switch to new fids. - var oldReadFile p9file - if openReadable { - oldReadFile = d.readFile - d.readFile = h.file - } - var oldWriteFile p9file - if openWritable { - oldWriteFile = d.writeFile - d.writeFile = h.file - } - // NOTE(b/141991141): Clunk old fids before making new fids visible (by - // unlocking d.handleMu). - if !oldReadFile.isNil() { - oldReadFile.close(ctx) - } - if !oldWriteFile.isNil() && oldReadFile != oldWriteFile { - oldWriteFile.close(ctx) + // Switch to new fids/FDs. + if d.fs.opts.lisaEnabled { + oldReadFD := lisafs.InvalidFDID + if openReadable { + oldReadFD = d.readFDLisa.ID() + d.readFDLisa = h.fdLisa + } + oldWriteFD := lisafs.InvalidFDID + if openWritable { + oldWriteFD = d.writeFDLisa.ID() + d.writeFDLisa = h.fdLisa + } + // NOTE(b/141991141): Close old FDs before making new fids visible (by + // unlocking d.handleMu). + if oldReadFD.Ok() { + d.fs.clientLisa.CloseFDBatched(ctx, oldReadFD) + } + if oldWriteFD.Ok() && oldReadFD != oldWriteFD { + d.fs.clientLisa.CloseFDBatched(ctx, oldWriteFD) + } + } else { + var oldReadFile p9file + if openReadable { + oldReadFile = d.readFile + d.readFile = h.file + } + var oldWriteFile p9file + if openWritable { + oldWriteFile = d.writeFile + d.writeFile = h.file + } + // NOTE(b/141991141): Clunk old fids before making new fids visible (by + // unlocking d.handleMu). + if !oldReadFile.isNil() { + oldReadFile.close(ctx) + } + if !oldWriteFile.isNil() && oldReadFile != oldWriteFile { + oldWriteFile.close(ctx) + } } } d.handleMu.Unlock() @@ -1930,27 +2330,29 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool // Preconditions: d.handleMu must be locked. func (d *dentry) readHandleLocked() handle { return handle{ - file: d.readFile, - fd: d.readFD, + fdLisa: d.readFDLisa, + file: d.readFile, + fd: d.readFD, } } // Preconditions: d.handleMu must be locked. func (d *dentry) writeHandleLocked() handle { return handle{ - file: d.writeFile, - fd: d.writeFD, + fdLisa: d.writeFDLisa, + file: d.writeFile, + fd: d.writeFD, } } func (d *dentry) syncRemoteFile(ctx context.Context) error { d.handleMu.RLock() defer d.handleMu.RUnlock() - return d.syncRemoteFileLocked(ctx) + return d.syncRemoteFileLocked(ctx, nil /* accFsyncFDIDsLisa */) } // Preconditions: d.handleMu must be locked. -func (d *dentry) syncRemoteFileLocked(ctx context.Context) error { +func (d *dentry) syncRemoteFileLocked(ctx context.Context, accFsyncFDIDsLisa *[]lisafs.FDID) error { // If we have a host FD, fsyncing it is likely to be faster than an fsync // RPC. Prefer syncing write handles over read handles, since some remote // filesystem implementations may not sync changes made through write @@ -1961,7 +2363,13 @@ func (d *dentry) syncRemoteFileLocked(ctx context.Context) error { ctx.UninterruptibleSleepFinish(false) return err } - if !d.writeFile.isNil() { + if d.fs.opts.lisaEnabled && d.writeFDLisa.Ok() { + if accFsyncFDIDsLisa != nil { + *accFsyncFDIDsLisa = append(*accFsyncFDIDsLisa, d.writeFDLisa.ID()) + return nil + } + return d.writeFDLisa.Sync(ctx) + } else if !d.fs.opts.lisaEnabled && !d.writeFile.isNil() { return d.writeFile.fsync(ctx) } if d.readFD >= 0 { @@ -1970,13 +2378,19 @@ func (d *dentry) syncRemoteFileLocked(ctx context.Context) error { ctx.UninterruptibleSleepFinish(false) return err } - if !d.readFile.isNil() { + if d.fs.opts.lisaEnabled && d.readFDLisa.Ok() { + if accFsyncFDIDsLisa != nil { + *accFsyncFDIDsLisa = append(*accFsyncFDIDsLisa, d.readFDLisa.ID()) + return nil + } + return d.readFDLisa.Sync(ctx) + } else if !d.fs.opts.lisaEnabled && !d.readFile.isNil() { return d.readFile.fsync(ctx) } return nil } -func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool) error { +func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool, accFsyncFDIDsLisa *[]lisafs.FDID) error { d.handleMu.RLock() defer d.handleMu.RUnlock() h := d.writeHandleLocked() @@ -1989,7 +2403,7 @@ func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool) err return err } } - if err := d.syncRemoteFileLocked(ctx); err != nil { + if err := d.syncRemoteFileLocked(ctx, accFsyncFDIDsLisa); err != nil { if !forFilesystemSync { return err } @@ -2046,18 +2460,33 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu d := fd.dentry() const validMask = uint32(linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME) if !d.cachedMetadataAuthoritative() && opts.Mask&validMask != 0 && opts.Sync != linux.AT_STATX_DONT_SYNC { - // Use specialFileFD.handle.file for the getattr if available, for the - // same reason that we try to use open file handles in - // dentry.updateFromGetattrLocked(). - var file p9file - if sffd, ok := fd.vfsfd.Impl().(*specialFileFD); ok { - file = sffd.handle.file - } - d.metadataMu.Lock() - err := d.updateFromGetattrLocked(ctx, file) - d.metadataMu.Unlock() - if err != nil { - return linux.Statx{}, err + if d.fs.opts.lisaEnabled { + // Use specialFileFD.handle.fileLisa for the Stat if available, for the + // same reason that we try to use open FD in updateFromStatLisaLocked(). + var fdLisa *lisafs.ClientFD + if sffd, ok := fd.vfsfd.Impl().(*specialFileFD); ok { + fdLisa = &sffd.handle.fdLisa + } + d.metadataMu.Lock() + err := d.updateFromStatLisaLocked(ctx, fdLisa) + d.metadataMu.Unlock() + if err != nil { + return linux.Statx{}, err + } + } else { + // Use specialFileFD.handle.file for the getattr if available, for the + // same reason that we try to use open file handles in + // dentry.updateFromGetattrLocked(). + var file p9file + if sffd, ok := fd.vfsfd.Impl().(*specialFileFD); ok { + file = sffd.handle.file + } + d.metadataMu.Lock() + err := d.updateFromGetattrLocked(ctx, file) + d.metadataMu.Unlock() + if err != nil { + return linux.Statx{}, err + } } } var stat linux.Statx @@ -2078,7 +2507,7 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) // ListXattr implements vfs.FileDescriptionImpl.ListXattr. func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) { - return fd.dentry().listXattr(ctx, auth.CredentialsFromContext(ctx), size) + return fd.dentry().listXattr(ctx, size) } // GetXattr implements vfs.FileDescriptionImpl.GetXattr. diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go index 806392d50..d5cc73f33 100644 --- a/pkg/sentry/fsimpl/gofer/gofer_test.go +++ b/pkg/sentry/fsimpl/gofer/gofer_test.go @@ -33,6 +33,7 @@ func TestDestroyIdempotent(t *testing.T) { }, syncableDentries: make(map[*dentry]struct{}), inoByQIDPath: make(map[uint64]uint64), + inoByKey: make(map[inoKey]uint64), } attr := &p9.Attr{ diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go index 02540a754..394aecd62 100644 --- a/pkg/sentry/fsimpl/gofer/handle.go +++ b/pkg/sentry/fsimpl/gofer/handle.go @@ -17,6 +17,7 @@ package gofer import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/lisafs" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/hostfd" @@ -26,10 +27,13 @@ import ( // handle represents a remote "open file descriptor", consisting of an opened // fid (p9.File) and optionally a host file descriptor. // +// If lisafs is being used, fdLisa points to an open file on the server. +// // These are explicitly not savable. type handle struct { - file p9file - fd int32 // -1 if unavailable + fdLisa lisafs.ClientFD + file p9file + fd int32 // -1 if unavailable } // Preconditions: read || write. @@ -65,13 +69,47 @@ func openHandle(ctx context.Context, file p9file, read, write, trunc bool) (hand }, nil } +// Preconditions: read || write. +func openHandleLisa(ctx context.Context, fdLisa lisafs.ClientFD, read, write, trunc bool) (handle, error) { + var flags uint32 + switch { + case read && write: + flags = unix.O_RDWR + case read: + flags = unix.O_RDONLY + case write: + flags = unix.O_WRONLY + default: + panic("tried to open unreadable and unwritable handle") + } + if trunc { + flags |= unix.O_TRUNC + } + openFD, hostFD, err := fdLisa.OpenAt(ctx, flags) + if err != nil { + return handle{fd: -1}, err + } + h := handle{ + fdLisa: fdLisa.Client().NewFD(openFD), + fd: int32(hostFD), + } + return h, nil +} + func (h *handle) isOpen() bool { + if h.fdLisa.Client() != nil { + return h.fdLisa.Ok() + } return !h.file.isNil() } func (h *handle) close(ctx context.Context) { - h.file.close(ctx) - h.file = p9file{} + if h.fdLisa.Client() != nil { + h.fdLisa.CloseBatched(ctx) + } else { + h.file.close(ctx) + h.file = p9file{} + } if h.fd >= 0 { unix.Close(int(h.fd)) h.fd = -1 @@ -89,19 +127,27 @@ func (h *handle) readToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offs return n, err } if dsts.NumBlocks() == 1 && !dsts.Head().NeedSafecopy() { - n, err := h.file.readAt(ctx, dsts.Head().ToSlice(), offset) - return uint64(n), err + if h.fdLisa.Client() != nil { + return h.fdLisa.Read(ctx, dsts.Head().ToSlice(), offset) + } + return h.file.readAt(ctx, dsts.Head().ToSlice(), offset) } // Buffer the read since p9.File.ReadAt() takes []byte. buf := make([]byte, dsts.NumBytes()) - n, err := h.file.readAt(ctx, buf, offset) + var n uint64 + var err error + if h.fdLisa.Client() != nil { + n, err = h.fdLisa.Read(ctx, buf, offset) + } else { + n, err = h.file.readAt(ctx, buf, offset) + } if n == 0 { return 0, err } if cp, cperr := safemem.CopySeq(dsts, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:n]))); cperr != nil { return cp, cperr } - return uint64(n), err + return n, err } func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) { @@ -115,8 +161,10 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o return n, err } if srcs.NumBlocks() == 1 && !srcs.Head().NeedSafecopy() { - n, err := h.file.writeAt(ctx, srcs.Head().ToSlice(), offset) - return uint64(n), err + if h.fdLisa.Client() != nil { + return h.fdLisa.Write(ctx, srcs.Head().ToSlice(), offset) + } + return h.file.writeAt(ctx, srcs.Head().ToSlice(), offset) } // Buffer the write since p9.File.WriteAt() takes []byte. buf := make([]byte, srcs.NumBytes()) @@ -124,12 +172,18 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o if cp == 0 { return 0, cperr } - n, err := h.file.writeAt(ctx, buf[:cp], offset) + var n uint64 + var err error + if h.fdLisa.Client() != nil { + n, err = h.fdLisa.Write(ctx, buf[:cp], offset) + } else { + n, err = h.file.writeAt(ctx, buf[:cp], offset) + } // err takes precedence over cperr. if err != nil { - return uint64(n), err + return n, err } - return uint64(n), cperr + return n, cperr } type handleReadWriter struct { diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go index 5a3ddfc9d..0d97b60fd 100644 --- a/pkg/sentry/fsimpl/gofer/p9file.go +++ b/pkg/sentry/fsimpl/gofer/p9file.go @@ -141,18 +141,18 @@ func (f p9file) open(ctx context.Context, flags p9.OpenFlags) (*fd.FD, p9.QID, u return fdobj, qid, iounit, err } -func (f p9file) readAt(ctx context.Context, p []byte, offset uint64) (int, error) { +func (f p9file) readAt(ctx context.Context, p []byte, offset uint64) (uint64, error) { ctx.UninterruptibleSleepStart(false) n, err := f.file.ReadAt(p, offset) ctx.UninterruptibleSleepFinish(false) - return n, err + return uint64(n), err } -func (f p9file) writeAt(ctx context.Context, p []byte, offset uint64) (int, error) { +func (f p9file) writeAt(ctx context.Context, p []byte, offset uint64) (uint64, error) { ctx.UninterruptibleSleepStart(false) n, err := f.file.WriteAt(p, offset) ctx.UninterruptibleSleepFinish(false) - return n, err + return uint64(n), err } func (f p9file) fsync(ctx context.Context) error { diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 947dbe05f..874f9873d 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -98,6 +98,12 @@ func (fd *regularFileFD) OnClose(ctx context.Context) error { } d.handleMu.RLock() defer d.handleMu.RUnlock() + if d.fs.opts.lisaEnabled { + if !d.writeFDLisa.Ok() { + return nil + } + return d.writeFDLisa.Flush(ctx) + } if d.writeFile.isNil() { return nil } @@ -110,6 +116,9 @@ func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint return d.doAllocate(ctx, offset, length, func() error { d.handleMu.RLock() defer d.handleMu.RUnlock() + if d.fs.opts.lisaEnabled { + return d.writeFDLisa.Allocate(ctx, mode, offset, length) + } return d.writeFile.allocate(ctx, p9.ToAllocateMode(mode), offset, length) }) } @@ -282,8 +291,19 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off // changes to the host. if newMode := vfs.ClearSUIDAndSGID(oldMode); newMode != oldMode { atomic.StoreUint32(&d.mode, newMode) - if err := d.file.setAttr(ctx, p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(newMode)}); err != nil { - return 0, offset, err + if d.fs.opts.lisaEnabled { + stat := linux.Statx{Mask: linux.STATX_MODE, Mode: uint16(newMode)} + failureMask, failureErr, err := d.controlFDLisa.SetStat(ctx, &stat) + if err != nil { + return 0, offset, err + } + if failureMask != 0 { + return 0, offset, failureErr + } + } else { + if err := d.file.setAttr(ctx, p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(newMode)}); err != nil { + return 0, offset, err + } } } } @@ -677,7 +697,7 @@ func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int6 // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *regularFileFD) Sync(ctx context.Context) error { - return fd.dentry().syncCachedFile(ctx, false /* lowSyncExpectations */) + return fd.dentry().syncCachedFile(ctx, false /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */) } // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. diff --git a/pkg/sentry/fsimpl/gofer/revalidate.go b/pkg/sentry/fsimpl/gofer/revalidate.go index 226790a11..5d4009832 100644 --- a/pkg/sentry/fsimpl/gofer/revalidate.go +++ b/pkg/sentry/fsimpl/gofer/revalidate.go @@ -15,7 +15,9 @@ package gofer import ( + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" ) @@ -234,28 +236,54 @@ func (fs *filesystem) revalidateHelper(ctx context.Context, vfsObj *vfs.VirtualF } // Lock metadata on all dentries *before* getting attributes for them. state.lockAllMetadata() - stats, err := state.start.file.multiGetAttr(ctx, state.names) - if err != nil { - return err + + var ( + stats []p9.FullStat + statsLisa []linux.Statx + numStats int + ) + if fs.opts.lisaEnabled { + var err error + statsLisa, err = state.start.controlFDLisa.WalkStat(ctx, state.names) + if err != nil { + return err + } + numStats = len(statsLisa) + } else { + var err error + stats, err = state.start.file.multiGetAttr(ctx, state.names) + if err != nil { + return err + } + numStats = len(stats) } i := -1 for d := state.popFront(); d != nil; d = state.popFront() { i++ - found := i < len(stats) + found := i < numStats if i == 0 && len(state.names[0]) == 0 { if found && !d.isSynthetic() { // First dentry is where the search is starting, just update attributes // since it cannot be replaced. - d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: acquired by lockAllMetadata. + if fs.opts.lisaEnabled { + d.updateFromLisaStatLocked(&statsLisa[i]) // +checklocksforce: acquired by lockAllMetadata. + } else { + d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: acquired by lockAllMetadata. + } } d.metadataMu.Unlock() // +checklocksforce: see above. continue } - // Note that synthetic dentries will always fails the comparison check - // below. - if !found || d.qidPath != stats[i].QID.Path { + // Note that synthetic dentries will always fail this comparison check. + var shouldInvalidate bool + if fs.opts.lisaEnabled { + shouldInvalidate = !found || d.inoKey != inoKeyFromStat(&statsLisa[i]) + } else { + shouldInvalidate = !found || d.qidPath != stats[i].QID.Path + } + if shouldInvalidate { d.metadataMu.Unlock() // +checklocksforce: see above. if !found && d.isSynthetic() { // We have a synthetic file, and no remote file has arisen to replace @@ -298,7 +326,11 @@ func (fs *filesystem) revalidateHelper(ctx context.Context, vfsObj *vfs.VirtualF } // The file at this path hasn't changed. Just update cached metadata. - d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: see above. + if fs.opts.lisaEnabled { + d.updateFromLisaStatLocked(&statsLisa[i]) // +checklocksforce: see above. + } else { + d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: see above. + } d.metadataMu.Unlock() } diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go index 8dcbc61ed..82878c056 100644 --- a/pkg/sentry/fsimpl/gofer/save_restore.go +++ b/pkg/sentry/fsimpl/gofer/save_restore.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/lisafs" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/safemem" @@ -112,10 +113,19 @@ func (d *dentry) prepareSaveRecursive(ctx context.Context) error { return err } } - if !d.readFile.isNil() || !d.writeFile.isNil() { - d.fs.savedDentryRW[d] = savedDentryRW{ - read: !d.readFile.isNil(), - write: !d.writeFile.isNil(), + if d.fs.opts.lisaEnabled { + if d.readFDLisa.Ok() || d.writeFDLisa.Ok() { + d.fs.savedDentryRW[d] = savedDentryRW{ + read: d.readFDLisa.Ok(), + write: d.writeFDLisa.Ok(), + } + } + } else { + if !d.readFile.isNil() || !d.writeFile.isNil() { + d.fs.savedDentryRW[d] = savedDentryRW{ + read: !d.readFile.isNil(), + write: !d.writeFile.isNil(), + } } } d.dirMu.Lock() @@ -177,25 +187,37 @@ func (fs *filesystem) CompleteRestore(ctx context.Context, opts vfs.CompleteRest return fmt.Errorf("no server FD available for filesystem with unique ID %q", fs.iopts.UniqueID) } fs.opts.fd = fd - if err := fs.dial(ctx); err != nil { - return err - } fs.inoByQIDPath = make(map[uint64]uint64) + fs.inoByKey = make(map[inoKey]uint64) - // Restore the filesystem root. - ctx.UninterruptibleSleepStart(false) - attached, err := fs.client.Attach(fs.opts.aname) - ctx.UninterruptibleSleepFinish(false) - if err != nil { - return err - } - attachFile := p9file{attached} - qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask()) - if err != nil { - return err - } - if err := fs.root.restoreFile(ctx, attachFile, qid, attrMask, &attr, &opts); err != nil { - return err + if fs.opts.lisaEnabled { + rootInode, err := fs.initClientLisa(ctx) + if err != nil { + return err + } + if err := fs.root.restoreFileLisa(ctx, rootInode, &opts); err != nil { + return err + } + } else { + if err := fs.dial(ctx); err != nil { + return err + } + + // Restore the filesystem root. + ctx.UninterruptibleSleepStart(false) + attached, err := fs.client.Attach(fs.opts.aname) + ctx.UninterruptibleSleepFinish(false) + if err != nil { + return err + } + attachFile := p9file{attached} + qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask()) + if err != nil { + return err + } + if err := fs.root.restoreFile(ctx, attachFile, qid, attrMask, &attr, &opts); err != nil { + return err + } } // Restore remaining dentries. @@ -255,18 +277,18 @@ func (d *dentry) restoreFile(ctx context.Context, file p9file, qid p9.QID, attrM if d.isRegularFile() { if opts.ValidateFileSizes { if !attrMask.Size { - return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: file size not available", genericDebugPathname(d)) + return vfs.ErrCorruption{fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: file size not available", genericDebugPathname(d))} } if d.size != attr.Size { - return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: size changed from %d to %d", genericDebugPathname(d), d.size, attr.Size) + return vfs.ErrCorruption{fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: size changed from %d to %d", genericDebugPathname(d), d.size, attr.Size)} } } if opts.ValidateFileModificationTimestamps { if !attrMask.MTime { - return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime not available", genericDebugPathname(d)) + return vfs.ErrCorruption{fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime not available", genericDebugPathname(d))} } if want := dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds); d.mtime != want { - return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime changed from %+v to %+v", genericDebugPathname(d), linux.NsecToStatxTimestamp(d.mtime), linux.NsecToStatxTimestamp(want)) + return vfs.ErrCorruption{fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime changed from %+v to %+v", genericDebugPathname(d), linux.NsecToStatxTimestamp(d.mtime), linux.NsecToStatxTimestamp(want))} } } } @@ -283,6 +305,55 @@ func (d *dentry) restoreFile(ctx context.Context, file p9file, qid p9.QID, attrM return nil } +func (d *dentry) restoreFileLisa(ctx context.Context, inode *lisafs.Inode, opts *vfs.CompleteRestoreOptions) error { + d.controlFDLisa = d.fs.clientLisa.NewFD(inode.ControlFD) + + // Gofers do not preserve inoKey across checkpoint/restore, so: + // + // - We must assume that the remote filesystem did not change in a way that + // would invalidate dentries, since we can't revalidate dentries by + // checking inoKey. + // + // - We need to associate the new inoKey with the existing d.ino. + d.inoKey = inoKeyFromStat(&inode.Stat) + d.fs.inoMu.Lock() + d.fs.inoByKey[d.inoKey] = d.ino + d.fs.inoMu.Unlock() + + // Check metadata stability before updating metadata. + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + if d.isRegularFile() { + if opts.ValidateFileSizes { + if inode.Stat.Mask&linux.STATX_SIZE != 0 { + return vfs.ErrCorruption{fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: file size not available", genericDebugPathname(d))} + } + if d.size != inode.Stat.Size { + return vfs.ErrCorruption{fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: size changed from %d to %d", genericDebugPathname(d), d.size, inode.Stat.Size)} + } + } + if opts.ValidateFileModificationTimestamps { + if inode.Stat.Mask&linux.STATX_MTIME != 0 { + return vfs.ErrCorruption{fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime not available", genericDebugPathname(d))} + } + if want := dentryTimestampFromLisa(inode.Stat.Mtime); d.mtime != want { + return vfs.ErrCorruption{fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime changed from %+v to %+v", genericDebugPathname(d), linux.NsecToStatxTimestamp(d.mtime), linux.NsecToStatxTimestamp(want))} + } + } + } + if !d.cachedMetadataAuthoritative() { + d.updateFromLisaStatLocked(&inode.Stat) + } + + if rw, ok := d.fs.savedDentryRW[d]; ok { + if err := d.ensureSharedHandle(ctx, rw.read, rw.write, false /* trunc */); err != nil { + return err + } + } + + return nil +} + // Preconditions: d is not synthetic. func (d *dentry) restoreDescendantsRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error { for _, child := range d.children { @@ -305,19 +376,35 @@ func (d *dentry) restoreDescendantsRecursive(ctx context.Context, opts *vfs.Comp // only be detected by checking filesystem.syncableDentries). d.parent has been // restored. func (d *dentry) restoreRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error { - qid, file, attrMask, attr, err := d.parent.file.walkGetAttrOne(ctx, d.name) - if err != nil { - return err - } - if err := d.restoreFile(ctx, file, qid, attrMask, &attr, opts); err != nil { - return err + if d.fs.opts.lisaEnabled { + inode, err := d.parent.controlFDLisa.Walk(ctx, d.name) + if err != nil { + return err + } + if err := d.restoreFileLisa(ctx, inode, opts); err != nil { + return err + } + } else { + qid, file, attrMask, attr, err := d.parent.file.walkGetAttrOne(ctx, d.name) + if err != nil { + return err + } + if err := d.restoreFile(ctx, file, qid, attrMask, &attr, opts); err != nil { + return err + } } return d.restoreDescendantsRecursive(ctx, opts) } func (fd *specialFileFD) completeRestore(ctx context.Context) error { d := fd.dentry() - h, err := openHandle(ctx, d.file, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */) + var h handle + var err error + if d.fs.opts.lisaEnabled { + h, err = openHandleLisa(ctx, d.controlFDLisa, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */) + } else { + h, err = openHandle(ctx, d.file, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */) + } if err != nil { return err } diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go index fe15f8583..86ab70453 100644 --- a/pkg/sentry/fsimpl/gofer/socket.go +++ b/pkg/sentry/fsimpl/gofer/socket.go @@ -59,11 +59,6 @@ func sockTypeToP9(t linux.SockType) (p9.ConnectFlags, bool) { // BidirectionalConnect implements ConnectableEndpoint.BidirectionalConnect. func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error { - cf, ok := sockTypeToP9(ce.Type()) - if !ok { - return syserr.ErrConnectionRefused - } - // No lock ordering required as only the ConnectingEndpoint has a mutex. ce.Lock() @@ -77,7 +72,7 @@ func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.Connec return syserr.ErrInvalidEndpointState } - c, err := e.newConnectedEndpoint(ctx, cf, ce.WaiterQueue()) + c, err := e.newConnectedEndpoint(ctx, ce.Type(), ce.WaiterQueue()) if err != nil { ce.Unlock() return err @@ -95,7 +90,7 @@ func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.Connec // UnidirectionalConnect implements // transport.BoundEndpoint.UnidirectionalConnect. func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.ConnectedEndpoint, *syserr.Error) { - c, err := e.newConnectedEndpoint(ctx, p9.DgramSocket, &waiter.Queue{}) + c, err := e.newConnectedEndpoint(ctx, linux.SOCK_DGRAM, &waiter.Queue{}) if err != nil { return nil, err } @@ -111,25 +106,39 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect return c, nil } -func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFlags, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) { - hostFile, err := e.dentry.file.connect(ctx, flags) - if err != nil { +func (e *endpoint) newConnectedEndpoint(ctx context.Context, sockType linux.SockType, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) { + if e.dentry.fs.opts.lisaEnabled { + hostSockFD, err := e.dentry.controlFDLisa.Connect(ctx, sockType) + if err != nil { + return nil, syserr.ErrConnectionRefused + } + + c, serr := host.NewSCMEndpoint(ctx, hostSockFD, queue, e.path) + if serr != nil { + unix.Close(hostSockFD) + log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v sockType %d: %v", e.dentry.file, sockType, serr) + return nil, serr + } + return c, nil + } + + flags, ok := sockTypeToP9(sockType) + if !ok { return nil, syserr.ErrConnectionRefused } - // Dup the fd so that the new endpoint can manage its lifetime. - hostFD, err := unix.Dup(hostFile.FD()) + hostFile, err := e.dentry.file.connect(ctx, flags) if err != nil { - log.Warningf("Could not dup host socket fd %d: %v", hostFile.FD(), err) - return nil, syserr.FromError(err) + return nil, syserr.ErrConnectionRefused } - // After duplicating, we no longer need hostFile. - hostFile.Close() - c, serr := host.NewSCMEndpoint(ctx, hostFD, queue, e.path) + c, serr := host.NewSCMEndpoint(ctx, hostFile.FD(), queue, e.path) if serr != nil { - log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.dentry.file, flags, serr) + hostFile.Close() + log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v sockType %d: %v", e.dentry.file, sockType, serr) return nil, serr } + // Ownership has been transferred to c. + hostFile.Release() return c, nil } diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index a8d47b65b..c568bbfd2 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/lisafs" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" @@ -149,6 +150,9 @@ func (fd *specialFileFD) OnClose(ctx context.Context) error { if !fd.vfsfd.IsWritable() { return nil } + if fs := fd.filesystem(); fs.opts.lisaEnabled { + return fd.handle.fdLisa.Flush(ctx) + } return fd.handle.file.flush(ctx) } @@ -184,6 +188,9 @@ func (fd *specialFileFD) Allocate(ctx context.Context, mode, offset, length uint if fd.isRegularFile { d := fd.dentry() return d.doAllocate(ctx, offset, length, func() error { + if d.fs.opts.lisaEnabled { + return fd.handle.fdLisa.Allocate(ctx, mode, offset, length) + } return fd.handle.file.allocate(ctx, p9.ToAllocateMode(mode), offset, length) }) } @@ -371,10 +378,10 @@ func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) ( // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *specialFileFD) Sync(ctx context.Context) error { - return fd.sync(ctx, false /* forFilesystemSync */) + return fd.sync(ctx, false /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */) } -func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error { +func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool, accFsyncFDIDsLisa *[]lisafs.FDID) error { // Locks to ensure it didn't race with fd.Release(). fd.releaseMu.RLock() defer fd.releaseMu.RUnlock() @@ -391,6 +398,13 @@ func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error ctx.UninterruptibleSleepFinish(false) return err } + if fs := fd.filesystem(); fs.opts.lisaEnabled { + if accFsyncFDIDsLisa != nil { + *accFsyncFDIDsLisa = append(*accFsyncFDIDsLisa, fd.handle.fdLisa.ID()) + return nil + } + return fd.handle.fdLisa.Sync(ctx) + } return fd.handle.file.fsync(ctx) }() if err != nil { diff --git a/pkg/sentry/fsimpl/gofer/symlink.go b/pkg/sentry/fsimpl/gofer/symlink.go index dbd834c67..27d9be5c4 100644 --- a/pkg/sentry/fsimpl/gofer/symlink.go +++ b/pkg/sentry/fsimpl/gofer/symlink.go @@ -35,7 +35,13 @@ func (d *dentry) readlink(ctx context.Context, mnt *vfs.Mount) (string, error) { return target, nil } } - target, err := d.file.readlink(ctx) + var target string + var err error + if d.fs.opts.lisaEnabled { + target, err = d.controlFDLisa.ReadLinkAt(ctx) + } else { + target, err = d.file.readlink(ctx) + } if d.fs.opts.interop != InteropModeShared { if err == nil { d.haveTarget = true diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go index 9cbe805b9..07940b225 100644 --- a/pkg/sentry/fsimpl/gofer/time.go +++ b/pkg/sentry/fsimpl/gofer/time.go @@ -17,6 +17,7 @@ package gofer import ( "sync/atomic" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/vfs" ) @@ -24,6 +25,10 @@ func dentryTimestampFromP9(s, ns uint64) int64 { return int64(s*1e9 + ns) } +func dentryTimestampFromLisa(t linux.StatxTimestamp) int64 { + return t.Sec*1e9 + int64(t.Nsec) +} + // Preconditions: d.cachedMetadataAuthoritative() == true. func (d *dentry) touchAtime(mnt *vfs.Mount) { if mnt.Flags.NoATime || mnt.ReadOnly() { diff --git a/pkg/sentry/fsimpl/mqfs/BUILD b/pkg/sentry/fsimpl/mqfs/BUILD index e1a38686b..332c9b504 100644 --- a/pkg/sentry/fsimpl/mqfs/BUILD +++ b/pkg/sentry/fsimpl/mqfs/BUILD @@ -18,9 +18,9 @@ go_library( name = "mqfs", srcs = [ "mqfs.go", - "root.go", "queue.go", "registry.go", + "root.go", "root_inode_refs.go", ], visibility = ["//pkg/sentry:internal"], diff --git a/pkg/sentry/fsimpl/mqfs/mqfs.go b/pkg/sentry/fsimpl/mqfs/mqfs.go index ed559cd13..c2b53c9d0 100644 --- a/pkg/sentry/fsimpl/mqfs/mqfs.go +++ b/pkg/sentry/fsimpl/mqfs/mqfs.go @@ -30,6 +30,7 @@ import ( ) const ( + // Name is the user-visible filesystem name. Name = "mqueue" defaultMaxCachedDentries = uint64(1000) ) @@ -73,7 +74,7 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF } impl.fs.MaxCachedDentries = maxCachedDentries - impl.root.IncRef() + impl.fs.VFSFilesystem().IncRef() return impl.fs.VFSFilesystem(), impl.root.VFSDentry(), nil } @@ -109,7 +110,6 @@ type filesystem struct { func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) fs.Filesystem.Release(ctx) - fs.root.DecRef(ctx) } // MountOptions implements vfs.FilesystemImpl.MountOptions. diff --git a/pkg/sentry/fsimpl/mqfs/registry.go b/pkg/sentry/fsimpl/mqfs/registry.go index 2c9c79f01..c8fbe4d33 100644 --- a/pkg/sentry/fsimpl/mqfs/registry.go +++ b/pkg/sentry/fsimpl/mqfs/registry.go @@ -63,11 +63,12 @@ func NewRegistryImpl(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds * root: &dentry, } fs.VFSFilesystem().Init(vfsObj, &FilesystemType{}, fs) + vfsfs := fs.VFSFilesystem() dentry.InitRoot(&fs.Filesystem, fs.newRootInode(ctx, creds)) - dentry.IncRef() + defer vfsfs.DecRef(ctx) // NewDisconnectedMount will obtain a ref on success. - mount, err := vfsObj.NewDisconnectedMount(fs.VFSFilesystem(), dentry.VFSDentry(), &vfs.MountOptions{}) + mount, err := vfsObj.NewDisconnectedMount(vfsfs, dentry.VFSDentry(), &vfs.MountOptions{}) if err != nil { return nil, err } @@ -129,6 +130,7 @@ func (r *RegistryImpl) Unlink(ctx context.Context, name string) error { // Destroy implements mq.RegistryImpl.Destroy. func (r *RegistryImpl) Destroy(ctx context.Context) { r.root.DecRef(ctx) + r.mount.DecRef(ctx) } // lookup retreives a kernfs.Inode using a name. diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index 3b3dcf836..044902241 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -86,7 +86,7 @@ func putDentrySlice(ds *[]*dentry) { // fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() { // fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this. // -// +checklocksrelease:fs.renameMu +// +checklocksreleaseread:fs.renameMu func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, dsp **[]*dentry) { fs.renameMu.RUnlock() if *dsp == nil { diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 26d44744b..7b0be9c14 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -268,6 +268,6 @@ func cpuInfoData(k *kernel.Kernel) string { return buf.String() } -func shmData(v uint64) dynamicInode { +func ipcData(v uint64) dynamicInode { return newStaticFile(strconv.FormatUint(v, 10)) } diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index 4d3a2f7e6..faec36d8d 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -262,9 +262,8 @@ var _ dynamicInode = (*meminfoData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (*meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error { - k := kernel.KernelFromContext(ctx) - mf := k.MemoryFile() - mf.UpdateUsage() + mf := kernel.KernelFromContext(ctx).MemoryFile() + _ = mf.UpdateUsage() // Best effort snapshot, totalUsage := usage.MemoryAccounting.Copy() totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage) anon := snapshot.Anonymous + snapshot.Tmpfs diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 99f64a9d8..82e2857b3 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -47,9 +47,12 @@ func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k * "kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ "hostname": fs.newInode(ctx, root, 0444, &hostnameData{}), "sem": fs.newInode(ctx, root, 0444, newStaticFile(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))), - "shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)), - "shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)), - "shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)), + "shmall": fs.newInode(ctx, root, 0444, ipcData(linux.SHMALL)), + "shmmax": fs.newInode(ctx, root, 0444, ipcData(linux.SHMMAX)), + "shmmni": fs.newInode(ctx, root, 0444, ipcData(linux.SHMMNI)), + "msgmni": fs.newInode(ctx, root, 0444, ipcData(linux.MSGMNI)), + "msgmax": fs.newInode(ctx, root, 0444, ipcData(linux.MSGMAX)), + "msgmnb": fs.newInode(ctx, root, 0444, ipcData(linux.MSGMNB)), "yama": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ "ptrace_scope": fs.newYAMAPtraceScopeFile(ctx, k, root), }), diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index f322d2747..7fcb2d26b 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -84,6 +84,18 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fs.MaxCachedDentries = maxCachedDentries fs.VFSFilesystem().Init(vfsObj, &fsType, fs) + k := kernel.KernelFromContext(ctx) + fsDirChildren := make(map[string]kernfs.Inode) + // Create an empty directory to serve as the mount point for cgroupfs when + // cgroups are available. This emulates Linux behaviour, see + // kernel/cgroup.c:cgroup_init(). Note that in Linux, userspace (typically + // the init process) is ultimately responsible for actually mounting + // cgroupfs, but the kernel creates the mountpoint. For the sentry, the + // launcher mounts cgroupfs. + if k.CgroupRegistry() != nil { + fsDirChildren["cgroup"] = fs.newDir(ctx, creds, defaultSysDirMode, nil) + } + root := fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ "block": fs.newDir(ctx, creds, defaultSysDirMode, nil), "bus": fs.newDir(ctx, creds, defaultSysDirMode, nil), @@ -97,7 +109,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt }), }), "firmware": fs.newDir(ctx, creds, defaultSysDirMode, nil), - "fs": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "fs": fs.newDir(ctx, creds, defaultSysDirMode, fsDirChildren), "kernel": kernelDir(ctx, fs, creds), "module": fs.newDir(ctx, creds, defaultSysDirMode, nil), "power": fs.newDir(ctx, creds, defaultSysDirMode, nil), diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go index 0a0d914cc..0c46a3a13 100644 --- a/pkg/sentry/fsimpl/sys/sys_test.go +++ b/pkg/sentry/fsimpl/sys/sys_test.go @@ -87,3 +87,17 @@ func TestSysRootContainsExpectedEntries(t *testing.T) { "power": linux.DT_DIR, }) } + +func TestCgroupMountpointExists(t *testing.T) { + // Note: The mountpoint is only created if cgroups are available. This is + // the VFS2 implementation of sysfs and the test runs with VFS2 enabled, so + // we expect to see the mount point unconditionally. + s := newTestSystem(t) + defer s.Destroy() + pop := s.PathOpAtRoot("/fs") + s.AssertAllDirentTypes(s.ListDirents(pop), map[string]testutil.DirentType{ + "cgroup": linux.DT_DIR, + }) + pop = s.PathOpAtRoot("/fs/cgroup") + s.AssertAllDirentTypes(s.ListDirents(pop), map[string]testutil.DirentType{ /*empty*/ }) +} diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index 0f2ac6144..453e1aa61 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -95,7 +95,7 @@ type regularFile struct { func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode, parentDir *directory) *inode { file := ®ularFile{ memFile: fs.mfp.MemoryFile(), - memoryUsageKind: usage.Tmpfs, + memoryUsageKind: fs.usage, seals: linux.F_SEAL_SEAL, } file.inode.init(file, fs, kuid, kgid, linux.S_IFREG|mode, parentDir) diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index feafb06e4..f84165aba 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -41,6 +41,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/pgalloc" + "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sentry/vfs/memxattr" "gvisor.dev/gvisor/pkg/sync" @@ -74,6 +75,10 @@ type filesystem struct { // filesystem. Immutable. mopts string + // usage is the memory accounting category under which pages backing + // files in this filesystem are accounted. + usage usage.MemoryKind + // mu serializes changes to the Dentry tree. mu sync.RWMutex `state:"nosave"` @@ -106,6 +111,10 @@ type FilesystemOpts struct { // tmpfs filesystem. This allows tmpfs to "impersonate" other // filesystems, like ramdiskfs and cgroupfs. FilesystemType vfs.FilesystemType + + // Usage is the memory accounting category under which pages backing files in + // the filesystem are accounted. + Usage *usage.MemoryKind } // GetFilesystem implements vfs.FilesystemType.GetFilesystem. @@ -184,11 +193,16 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, err } clock := time.RealtimeClockFromContext(ctx) + memUsage := usage.Tmpfs + if tmpfsOpts.Usage != nil { + memUsage = *tmpfsOpts.Usage + } fs := filesystem{ mfp: mfp, clock: clock, devMinor: devMinor, mopts: opts.Data, + usage: memUsage, } fs.vfsfs.Init(vfsObj, newFSType, &fs) diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 52d47994d..8b059aa7d 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -74,7 +74,7 @@ func putDentrySlice(ds *[]*dentry) { // but dentry slices are allocated lazily, and it's much easier to say "defer // fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() { // fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this. -// +checklocksrelease:fs.renameMu +// +checklocksreleaseread:fs.renameMu func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) { fs.renameMu.RUnlock() if *ds == nil { diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD index 66fa1ad40..03c8e2f38 100644 --- a/pkg/sentry/hostmm/BUILD +++ b/pkg/sentry/hostmm/BUILD @@ -12,8 +12,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/fd", - "//pkg/hostarch", + "//pkg/eventfd", "//pkg/log", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/sentry/hostmm/hostmm.go b/pkg/sentry/hostmm/hostmm.go index 285ea9050..5df06a60f 100644 --- a/pkg/sentry/hostmm/hostmm.go +++ b/pkg/sentry/hostmm/hostmm.go @@ -21,9 +21,7 @@ import ( "os" "path" - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/eventfd" "gvisor.dev/gvisor/pkg/log" ) @@ -54,7 +52,7 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error) } defer eventControlFile.Close() - eventFD, err := newEventFD() + eventFD, err := eventfd.Create() if err != nil { return nil, err } @@ -75,20 +73,11 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error) const stopVal = 1 << 63 stopCh := make(chan struct{}) go func() { // S/R-SAFE: f provides synchronization if necessary - rw := fd.NewReadWriter(eventFD.FD()) - var buf [sizeofUint64]byte for { - n, err := rw.Read(buf[:]) + val, err := eventFD.Read() if err != nil { - if err == unix.EINTR { - continue - } panic(fmt.Sprintf("failed to read from memory pressure level eventfd: %v", err)) } - if n != sizeofUint64 { - panic(fmt.Sprintf("short read from memory pressure level eventfd: got %d bytes, wanted %d", n, sizeofUint64)) - } - val := hostarch.ByteOrder.Uint64(buf[:]) if val >= stopVal { // Assume this was due to the notifier's "destructor" (the // function returned by NotifyCurrentMemcgPressureCallback @@ -101,30 +90,7 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error) } }() return func() { - rw := fd.NewReadWriter(eventFD.FD()) - var buf [sizeofUint64]byte - hostarch.ByteOrder.PutUint64(buf[:], stopVal) - for { - n, err := rw.Write(buf[:]) - if err != nil { - if err == unix.EINTR { - continue - } - panic(fmt.Sprintf("failed to write to memory pressure level eventfd: %v", err)) - } - if n != sizeofUint64 { - panic(fmt.Sprintf("short write to memory pressure level eventfd: got %d bytes, wanted %d", n, sizeofUint64)) - } - break - } + eventFD.Write(stopVal) <-stopCh }, nil } - -func newEventFD() (*fd.FD, error) { - f, _, e := unix.Syscall(unix.SYS_EVENTFD2, 0, 0, 0) - if e != 0 { - return nil, fmt.Errorf("failed to create eventfd: %v", e) - } - return fd.New(int(f)), nil -} diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index 5bba9de0b..2363cec5f 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -1,13 +1,26 @@ load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") package( default_visibility = ["//:sandbox"], licenses = ["notice"], ) +go_template_instance( + name = "atomicptr_netns", + out = "atomicptr_netns_unsafe.go", + package = "inet", + prefix = "Namespace", + template = "//pkg/sync/atomicptr:generic_atomicptr", + types = { + "Value": "Namespace", + }, +) + go_library( name = "inet", srcs = [ + "atomicptr_netns_unsafe.go", "context.go", "inet.go", "namespace.go", diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 9f30a7706..f3f16eb7a 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -216,7 +216,6 @@ go_library( visibility = ["//:sandbox"], deps = [ ":uncaught_signal_go_proto", - "//pkg/sentry/kernel/ipc", "//pkg/abi", "//pkg/abi/linux", "//pkg/abi/linux/errno", @@ -257,8 +256,8 @@ go_library( "//pkg/sentry/hostcpu", "//pkg/sentry/inet", "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/epoll", "//pkg/sentry/kernel/futex", + "//pkg/sentry/kernel/ipc", "//pkg/sentry/kernel/mq", "//pkg/sentry/kernel/msgqueue", "//pkg/sentry/kernel/sched", diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index 6006c46a9..8d0a21baf 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -66,7 +66,7 @@ type pollEntry struct { file *refs.WeakRef `state:"manual"` id FileIdentifier `state:"wait"` userData [2]int32 - waiter waiter.Entry `state:"manual"` + waiter waiter.Entry mask waiter.EventMask flags EntryFlags @@ -102,7 +102,7 @@ type EventPoll struct { // Wait queue is used to notify interested parties when the event poll // object itself becomes readable or writable. - waiter.Queue `state:"zerovalue"` + waiter.Queue // files is the map of all the files currently being observed, it is // protected by mu. @@ -454,14 +454,3 @@ func (e *EventPoll) RemoveEntry(ctx context.Context, id FileIdentifier) error { return nil } - -// UnregisterEpollWaiters removes the epoll waiter objects from the waiting -// queues. This is different from Release() as the file is not dereferenced. -func (e *EventPoll) UnregisterEpollWaiters() { - e.mu.Lock() - defer e.mu.Unlock() - - for _, entry := range e.files { - entry.id.File.EventUnregister(&entry.waiter) - } -} diff --git a/pkg/sentry/kernel/epoll/epoll_state.go b/pkg/sentry/kernel/epoll/epoll_state.go index e08d6287f..135a6d72c 100644 --- a/pkg/sentry/kernel/epoll/epoll_state.go +++ b/pkg/sentry/kernel/epoll/epoll_state.go @@ -21,9 +21,7 @@ import ( // afterLoad is invoked by stateify. func (p *pollEntry) afterLoad() { - p.waiter.Callback = p p.file = refs.NewWeakRef(p.id.File, p) - p.id.File.EventRegister(&p.waiter, p.mask) } // afterLoad is invoked by stateify. diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go index 5ea44a2c2..bf625dede 100644 --- a/pkg/sentry/kernel/eventfd/eventfd.go +++ b/pkg/sentry/kernel/eventfd/eventfd.go @@ -54,7 +54,7 @@ type EventOperations struct { // Queue is used to notify interested parties when the event object // becomes readable or writable. - wq waiter.Queue `state:"zerovalue"` + wq waiter.Queue // val is the current value of the event counter. val uint64 diff --git a/pkg/sentry/kernel/ipc/BUILD b/pkg/sentry/kernel/ipc/BUILD index a5cbb2b51..bb5cf1c17 100644 --- a/pkg/sentry/kernel/ipc/BUILD +++ b/pkg/sentry/kernel/ipc/BUILD @@ -5,9 +5,9 @@ package(licenses = ["notice"]) go_library( name = "ipc", srcs = [ + "ns.go", "object.go", "registry.go", - "ns.go", ], visibility = ["//pkg/sentry:internal"], deps = [ diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 04b24369a..d4851ccda 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -57,7 +57,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/hostcpu" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/kernel/epoll" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/kernel/ipc" "gvisor.dev/gvisor/pkg/sentry/kernel/sched" @@ -79,11 +78,19 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -// VFS2Enabled is set to true when VFS2 is enabled. Added as a global for allow -// easy access everywhere. To be removed once VFS2 becomes the default. +// VFS2Enabled is set to true when VFS2 is enabled. Added as a global to allow +// easy access everywhere. +// +// TODO(gvisor.dev/issue/1624): Remove when VFS1 is no longer used. var VFS2Enabled = false -// FUSEEnabled is set to true when FUSE is enabled. Added as a global for allow +// LISAFSEnabled is set to true when lisafs protocol is enabled. Added as a +// global to allow easy access everywhere. +// +// TODO(gvisor.dev/issue/6319): Remove when lisafs is default. +var LISAFSEnabled = false + +// FUSEEnabled is set to true when FUSE is enabled. Added as a global to allow // easy access everywhere. To be removed once FUSE is completed. var FUSEEnabled = false @@ -484,11 +491,6 @@ func (k *Kernel) SaveTo(ctx context.Context, w wire.Writer) error { return err } - // Remove all epoll waiter objects from underlying wait queues. - // NOTE: for programs to resume execution in future snapshot scenarios, - // we will need to re-establish these waiter objects after saving. - k.tasks.unregisterEpollWaiters(ctx) - // Clear the dirent cache before saving because Dirents must be Loaded in a // particular order (parents before children), and Loading dirents from a cache // breaks that order. @@ -621,32 +623,6 @@ func (k *Kernel) flushWritesToFiles(ctx context.Context) error { }) } -// Preconditions: !VFS2Enabled. -func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) { - ts.mu.RLock() - defer ts.mu.RUnlock() - - // Tasks that belong to the same process could potentially point to the - // same FDTable. So we retain a map of processed ones to avoid - // processing the same FDTable multiple times. - processed := make(map[*FDTable]struct{}) - for t := range ts.Root.tids { - // We can skip locking Task.mu here since the kernel is paused. - if t.fdTable == nil { - continue - } - if _, ok := processed[t.fdTable]; ok { - continue - } - t.fdTable.forEach(ctx, func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { - if e, ok := file.FileOperations.(*epoll.EventPoll); ok { - e.UnregisterEpollWaiters() - } - }) - processed[t.fdTable] = struct{}{} - } -} - // Preconditions: The kernel must be paused. func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { invalidated := make(map[*mm.MemoryManager]struct{}) diff --git a/pkg/sentry/kernel/mq/mq.go b/pkg/sentry/kernel/mq/mq.go index a7c787081..50ca6d34a 100644 --- a/pkg/sentry/kernel/mq/mq.go +++ b/pkg/sentry/kernel/mq/mq.go @@ -40,8 +40,10 @@ const ( ReadWrite ) +// MaxName is the maximum size for a queue name. +const MaxName = 255 + const ( - MaxName = 255 // Maximum size for a queue name. maxPriority = linux.MQ_PRIO_MAX - 1 // Highest possible message priority. maxQueuesDefault = linux.DFLT_QUEUESMAX // Default max number of queues. diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 86beee6fe..8345473f3 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -55,7 +55,7 @@ const ( // // +stateify savable type Pipe struct { - waiter.Queue `state:"nosave"` + waiter.Queue // isNamed indicates whether this is a named pipe. // diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index 9a95bf44c..1ea3c1bf7 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -158,7 +158,7 @@ type Task struct { // signalQueue is protected by the signalMutex. Note that the task does // not implement all queue methods, specifically the readiness checks. // The task only broadcast a notification on signal delivery. - signalQueue waiter.Queue `state:"zerovalue"` + signalQueue waiter.Queue // If groupStopPending is true, the task should participate in a group // stop in the interrupt path. @@ -511,9 +511,7 @@ type Task struct { numaNodeMask uint64 // netns is the task's network namespace. netns is never nil. - // - // netns is protected by mu. - netns *inet.Namespace + netns inet.NamespaceAtomicPtr // If rseqPreempted is true, before the next call to p.Switch(), // interrupt rseq critical regions as defined by rseqAddr and diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index e174913d1..69a3227f0 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -447,7 +447,7 @@ func (t *Task) Unshare(flags int32) error { t.mu.Unlock() return linuxerr.EPERM } - t.netns = inet.NewNamespace(t.netns) + t.netns.Store(inet.NewNamespace(t.netns.Load())) } if flags&linux.CLONE_NEWUTS != 0 { if !haveCapSysAdmin { diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go index 8de08151a..f0c168ecc 100644 --- a/pkg/sentry/kernel/task_log.go +++ b/pkg/sentry/kernel/task_log.go @@ -191,9 +191,11 @@ const ( // // Preconditions: The task's owning TaskSet.mu must be locked. func (t *Task) updateInfoLocked() { - // Use the task's TID in the root PID namespace for logging. + // Use the task's TID and PID in the root PID namespace for logging. + pid := t.tg.pidns.owner.Root.tgids[t.tg] tid := t.tg.pidns.owner.Root.tids[t] - t.logPrefix.Store(fmt.Sprintf("[% 4d] ", tid)) + t.logPrefix.Store(fmt.Sprintf("[% 4d:% 4d] ", pid, tid)) + t.rebuildTraceContext(tid) } @@ -249,5 +251,9 @@ func (t *Task) traceExecEvent(image *TaskImage) { return } defer file.DecRef(t) - trace.Logf(t.traceContext, traceCategory, "exec: %s", file.PathnameWithDeleted(t)) + + // traceExecEvent function may be called before the task goroutine + // starts, so we must use the async context. + name := file.PathnameWithDeleted(t.AsyncContext()) + trace.Logf(t.traceContext, traceCategory, "exec: %s", name) } diff --git a/pkg/sentry/kernel/task_net.go b/pkg/sentry/kernel/task_net.go index f7711232c..e31e2b2e8 100644 --- a/pkg/sentry/kernel/task_net.go +++ b/pkg/sentry/kernel/task_net.go @@ -20,9 +20,7 @@ import ( // IsNetworkNamespaced returns true if t is in a non-root network namespace. func (t *Task) IsNetworkNamespaced() bool { - t.mu.Lock() - defer t.mu.Unlock() - return !t.netns.IsRoot() + return !t.netns.Load().IsRoot() } // NetworkContext returns the network stack used by the task. NetworkContext @@ -31,14 +29,10 @@ func (t *Task) IsNetworkNamespaced() bool { // TODO(gvisor.dev/issue/1833): Migrate callers of this method to // NetworkNamespace(). func (t *Task) NetworkContext() inet.Stack { - t.mu.Lock() - defer t.mu.Unlock() - return t.netns.Stack() + return t.netns.Load().Stack() } // NetworkNamespace returns the network namespace observed by the task. func (t *Task) NetworkNamespace() *inet.Namespace { - t.mu.Lock() - defer t.mu.Unlock() - return t.netns + return t.netns.Load() } diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index 217c6f531..4919dea7c 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -140,7 +140,6 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { allowedCPUMask: cfg.AllowedCPUMask.Copy(), ioUsage: &usage.IO{}, niceness: cfg.Niceness, - netns: cfg.NetworkNamespace, utsns: cfg.UTSNamespace, ipcns: cfg.IPCNamespace, abstractSockets: cfg.AbstractSocketNamespace, @@ -152,6 +151,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { containerID: cfg.ContainerID, cgroups: make(map[Cgroup]struct{}), } + t.netns.Store(cfg.NetworkNamespace) t.creds.Store(cfg.Credentials) t.endStopCond.L = &t.tg.signalHandlers.mu t.ptraceTracer.Store((*Task)(nil)) diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go index 77ad62445..e38b723ce 100644 --- a/pkg/sentry/kernel/threads.go +++ b/pkg/sentry/kernel/threads.go @@ -324,11 +324,7 @@ type threadGroupNode struct { // eventQueue is notified whenever a event of interest to Task.Wait occurs // in a child of this thread group, or a ptrace tracee of a task in this // thread group. Events are defined in task_exit.go. - // - // Note that we cannot check and save this wait queue similarly to other - // wait queues, as the queue will not be empty by the time of saving, due - // to the wait sourced from Exec(). - eventQueue waiter.Queue `state:"nosave"` + eventQueue waiter.Queue // leader is the thread group's leader, which is the oldest task in the // thread group; usually the last task in the thread group to call diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go index 9e00c2cec..dc12ad357 100644 --- a/pkg/sentry/mm/syscalls.go +++ b/pkg/sentry/mm/syscalls.go @@ -89,7 +89,7 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (hostar } // Offset + length must not overflow. if end := opts.Offset + opts.Length; end < opts.Offset { - return 0, linuxerr.ENOMEM + return 0, linuxerr.EOVERFLOW } } else { opts.Offset = 0 diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 8a490b3de..834d72408 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -1,13 +1,26 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) +go_template_instance( + name = "atomicptr_machine", + out = "atomicptr_machine_unsafe.go", + package = "kvm", + prefix = "machine", + template = "//pkg/sync/atomicptr:generic_atomicptr", + types = { + "Value": "machine", + }, +) + go_library( name = "kvm", srcs = [ "address_space.go", "address_space_amd64.go", "address_space_arm64.go", + "atomicptr_machine_unsafe.go", "bluepill.go", "bluepill_allocator.go", "bluepill_amd64.go", @@ -50,7 +63,6 @@ go_library( "//pkg/procid", "//pkg/ring0", "//pkg/ring0/pagetables", - "//pkg/safecopy", "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/arch/fpu", @@ -58,6 +70,7 @@ go_library( "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", "//pkg/sentry/time", + "//pkg/sighandling", "//pkg/sync", "@org_golang_x_sys//unix:go_default_library", ], @@ -69,10 +82,17 @@ go_test( "kvm_amd64_test.go", "kvm_amd64_test.s", "kvm_arm64_test.go", + "kvm_safecopy_test.go", "kvm_test.go", "virtual_map_test.go", ], library = ":kvm", + # FIXME(gvisor.dev/issue/3374): Not working with all build systems. + nogo = False, + # cgo has to be disabled. We have seen libc that blocks all signals and + # calls mmap from pthread_create, but we use SIGSYS to trap mmap system + # calls. + pure = True, tags = [ "manual", "nogotsan", @@ -81,8 +101,10 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/hostarch", + "//pkg/memutil", "//pkg/ring0", "//pkg/ring0/pagetables", + "//pkg/safecopy", "//pkg/sentry/arch", "//pkg/sentry/arch/fpu", "//pkg/sentry/platform", diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go index bb9967b9f..5be2215ed 100644 --- a/pkg/sentry/platform/kvm/bluepill.go +++ b/pkg/sentry/platform/kvm/bluepill.go @@ -19,8 +19,8 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/ring0" - "gvisor.dev/gvisor/pkg/safecopy" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sighandling" ) // bluepill enters guest mode. @@ -61,6 +61,9 @@ var ( // This is called by bluepillHandler. savedHandler uintptr + // savedSigsysHandler is a pointer to the previos handler of the SIGSYS signals. + savedSigsysHandler uintptr + // dieTrampolineAddr is the address of dieTrampoline. dieTrampolineAddr uintptr ) @@ -94,7 +97,7 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) { func init() { // Install the handler. - if err := safecopy.ReplaceSignalHandler(bluepillSignal, addrOfSighandler(), &savedHandler); err != nil { + if err := sighandling.ReplaceSignalHandler(bluepillSignal, addrOfSighandler(), &savedHandler); err != nil { panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err)) } diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go index 0567c8d32..b2db2bb9f 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64.go @@ -71,10 +71,6 @@ func (c *vCPU) KernelSyscall() { if regs.Rax != ^uint64(0) { regs.Rip -= 2 // Rewind. } - // We only trigger a bluepill entry in the bluepill function, and can - // therefore be guaranteed that there is no floating point state to be - // loaded on resuming from halt. We only worry about saving on exit. - ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no. // N.B. Since KernelSyscall is called when the kernel makes a syscall, // FS_BASE is already set for correct execution of this function. // @@ -112,8 +108,6 @@ func (c *vCPU) KernelException(vector ring0.Vector) { regs.Rip = 0 } // See above. - ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no. - // See above. ring0.HaltAndWriteFSBase(regs) // escapes: no, reload host segment. } @@ -144,5 +138,5 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) { // Set the context pointer to the saved floating point state. This is // where the guest data has been serialized, the kernel will restore // from this new pointer value. - context.Fpstate = uint64(uintptrValue(c.floatingPointState.BytePointer())) + context.Fpstate = uint64(uintptrValue(c.FloatingPointState().BytePointer())) // escapes: no. } diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s index c2a1dca11..5d8358f64 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.s +++ b/pkg/sentry/platform/kvm/bluepill_amd64.s @@ -32,6 +32,8 @@ // This is checked as the source of the fault. #define CLI $0xfa +#define SYS_MMAP 9 + // See bluepill.go. TEXT ·bluepill(SB),NOSPLIT,$0 begin: @@ -95,6 +97,31 @@ TEXT ·addrOfSighandler(SB), $0-8 MOVQ AX, ret+0(FP) RET +TEXT ·sigsysHandler(SB),NOSPLIT,$0 + // Check if the signal is from the kernel. + MOVQ $1, CX + CMPL CX, 0x8(SI) + JNE fallback + + MOVL CONTEXT_RAX(DX), CX + CMPL CX, $SYS_MMAP + JNE fallback + PUSHQ DX // First argument (context). + CALL ·seccompMmapHandler(SB) // Call the handler. + POPQ DX // Discard the argument. + RET +fallback: + // Jump to the previous signal handler. + XORQ CX, CX + MOVQ ·savedSigsysHandler(SB), AX + JMP AX + +// func addrOfSighandler() uintptr +TEXT ·addrOfSigsysHandler(SB), $0-8 + MOVQ $·sigsysHandler(SB), AX + MOVQ AX, ret+0(FP) + RET + // dieTrampoline: see bluepill.go, bluepill_amd64_unsafe.go for documentation. TEXT ·dieTrampoline(SB),NOSPLIT,$0 PUSHQ BX // First argument (vCPU). diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index acb0cb05f..df772d620 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -70,7 +70,7 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) { lazyVfp := c.GetLazyVFP() if lazyVfp != 0 { - fpsimd := fpsimdPtr(c.floatingPointState.BytePointer()) + fpsimd := fpsimdPtr(c.FloatingPointState().BytePointer()) // escapes: no context.Fpsimd64.Fpsr = fpsimd.Fpsr context.Fpsimd64.Fpcr = fpsimd.Fpcr context.Fpsimd64.Vregs = fpsimd.Vregs @@ -90,12 +90,12 @@ func (c *vCPU) KernelSyscall() { fpDisableTrap := ring0.CPACREL1() if fpDisableTrap != 0 { - fpsimd := fpsimdPtr(c.floatingPointState.BytePointer()) + fpsimd := fpsimdPtr(c.FloatingPointState().BytePointer()) // escapes: no fpcr := ring0.GetFPCR() fpsr := ring0.GetFPSR() fpsimd.Fpcr = uint32(fpcr) fpsimd.Fpsr = uint32(fpsr) - ring0.SaveVRegs(c.floatingPointState.BytePointer()) + ring0.SaveVRegs(c.FloatingPointState().BytePointer()) // escapes: no } ring0.Halt() @@ -114,12 +114,12 @@ func (c *vCPU) KernelException(vector ring0.Vector) { fpDisableTrap := ring0.CPACREL1() if fpDisableTrap != 0 { - fpsimd := fpsimdPtr(c.floatingPointState.BytePointer()) + fpsimd := fpsimdPtr(c.FloatingPointState().BytePointer()) // escapes: no fpcr := ring0.GetFPCR() fpsr := ring0.GetFPSR() fpsimd.Fpcr = uint32(fpcr) fpsimd.Fpsr = uint32(fpsr) - ring0.SaveVRegs(c.floatingPointState.BytePointer()) + ring0.SaveVRegs(c.FloatingPointState().BytePointer()) // escapes: no } ring0.Halt() diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s index 308f2a951..9690e3772 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.s +++ b/pkg/sentry/platform/kvm/bluepill_arm64.s @@ -29,9 +29,12 @@ // Only limited use of the context is done in the assembly stub below, most is // done in the Go handlers. #define SIGINFO_SIGNO 0x0 +#define SIGINFO_CODE 0x8 #define CONTEXT_PC 0x1B8 #define CONTEXT_R0 0xB8 +#define SYS_MMAP 222 + // getTLS returns the value of TPIDR_EL0 register. TEXT ·getTLS(SB),NOSPLIT,$0-8 MRS TPIDR_EL0, R1 @@ -98,6 +101,37 @@ TEXT ·addrOfSighandler(SB), $0-8 MOVD R0, ret+0(FP) RET +// The arguments are the following: +// +// R0 - The signal number. +// R1 - Pointer to siginfo_t structure. +// R2 - Pointer to ucontext structure. +// +TEXT ·sigsysHandler(SB),NOSPLIT,$0 + // si_code should be SYS_SECCOMP. + MOVD SIGINFO_CODE(R1), R7 + CMPW $1, R7 + BNE fallback + + CMPW $SYS_MMAP, R8 + BNE fallback + + MOVD R2, 8(RSP) + BL ·seccompMmapHandler(SB) // Call the handler. + + RET + +fallback: + // Jump to the previous signal handler. + MOVD ·savedHandler(SB), R7 + B (R7) + +// func addrOfSighandler() uintptr +TEXT ·addrOfSigsysHandler(SB), $0-8 + MOVD $·sigsysHandler(SB), R0 + MOVD R0, ret+0(FP) + RET + // dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation. TEXT ·dieTrampoline(SB),NOSPLIT,$0 // R0: Fake the old PC as caller diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index 0f0c1e73b..e38ca05c0 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -193,36 +193,8 @@ func bluepillHandler(context unsafe.Pointer) { return } - // Increment the fault count. - atomic.AddUint32(&c.faults, 1) - - // For MMIO, the physical address is the first data item. - physical = uintptr(c.runData.data[0]) - virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE) - if !ok { - c.die(bluepillArchContext(context), "invalid physical address") - return - } - - // We now need to fill in the data appropriately. KVM - // expects us to provide the result of the given MMIO - // operation in the runData struct. This is safe - // because, if a fault occurs here, the same fault - // would have occurred in guest mode. The kernel should - // not create invalid page table mappings. - data := (*[8]byte)(unsafe.Pointer(&c.runData.data[1])) - length := (uintptr)((uint32)(c.runData.data[2])) - write := (uint8)(((c.runData.data[2] >> 32) & 0xff)) != 0 - for i := uintptr(0); i < length; i++ { - b := bytePtr(uintptr(virtual) + i) - if write { - // Write to the given address. - *b = data[i] - } else { - // Read from the given address. - data[i] = *b - } - } + c.die(bluepillArchContext(context), "exit_mmio") + return case _KVM_EXIT_IRQ_WINDOW_OPEN: bluepillStopGuest(c) case _KVM_EXIT_SHUTDOWN: diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index aac0fdffe..ad6863646 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -77,7 +77,11 @@ var ( // OpenDevice opens the KVM device at /dev/kvm and returns the File. func OpenDevice() (*os.File, error) { - f, err := os.OpenFile("/dev/kvm", unix.O_RDWR, 0) + dev, ok := os.LookupEnv("GVISOR_KVM_DEV") + if !ok { + dev = "/dev/kvm" + } + f, err := os.OpenFile(dev, unix.O_RDWR, 0) if err != nil { return nil, fmt.Errorf("error opening /dev/kvm: %v", err) } diff --git a/pkg/sentry/platform/kvm/kvm_safecopy_test.go b/pkg/sentry/platform/kvm/kvm_safecopy_test.go new file mode 100644 index 000000000..fe488e707 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_safecopy_test.go @@ -0,0 +1,104 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// FIXME(gvisor.dev/issue/6629): These tests don't pass on ARM64. +// +//go:build amd64 +// +build amd64 + +package kvm + +import ( + "fmt" + "os" + "testing" + "unsafe" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/memutil" + "gvisor.dev/gvisor/pkg/safecopy" +) + +func testSafecopy(t *testing.T, mapSize uintptr, fileSize uintptr, testFunc func(t *testing.T, c *vCPU, addr uintptr)) { + memfd, err := memutil.CreateMemFD(fmt.Sprintf("kvm_test_%d", os.Getpid()), 0) + if err != nil { + t.Errorf("error creating memfd: %v", err) + } + + memfile := os.NewFile(uintptr(memfd), "kvm_test") + memfile.Truncate(int64(fileSize)) + kvmTest(t, nil, func(c *vCPU) bool { + const n = 10 + mappings := make([]uintptr, n) + defer func() { + for i := 0; i < n && mappings[i] != 0; i++ { + unix.RawSyscall( + unix.SYS_MUNMAP, + mappings[i], mapSize, 0) + } + }() + for i := 0; i < n; i++ { + addr, _, errno := unix.RawSyscall6( + unix.SYS_MMAP, + 0, + mapSize, + unix.PROT_READ|unix.PROT_WRITE, + unix.MAP_SHARED|unix.MAP_FILE, + uintptr(memfile.Fd()), + 0) + if errno != 0 { + t.Errorf("error mapping file: %v", errno) + } + mappings[i] = addr + testFunc(t, c, addr) + } + return false + }) +} + +func TestSafecopySigbus(t *testing.T) { + mapSize := uintptr(faultBlockSize) + fileSize := mapSize - hostarch.PageSize + buf := make([]byte, hostarch.PageSize) + testSafecopy(t, mapSize, fileSize, func(t *testing.T, c *vCPU, addr uintptr) { + want := safecopy.BusError{addr + fileSize} + bluepill(c) + _, err := safecopy.CopyIn(buf, unsafe.Pointer(addr+fileSize)) + if err != want { + t.Errorf("expected error: got %v, want %v", err, want) + } + }) +} + +func TestSafecopy(t *testing.T) { + mapSize := uintptr(faultBlockSize) + fileSize := mapSize + testSafecopy(t, mapSize, fileSize, func(t *testing.T, c *vCPU, addr uintptr) { + want := uint32(0x12345678) + bluepill(c) + _, err := safecopy.SwapUint32(unsafe.Pointer(addr+fileSize-8), want) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + bluepill(c) + val, err := safecopy.LoadUint32(unsafe.Pointer(addr + fileSize - 8)) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if val != want { + t.Errorf("incorrect value: got %x, want %x", val, want) + } + }) +} diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index d67563958..f1f7e4ea4 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -17,16 +17,20 @@ package kvm import ( "fmt" "runtime" + gosync "sync" "sync/atomic" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" + "gvisor.dev/gvisor/pkg/seccomp" ktime "gvisor.dev/gvisor/pkg/sentry/time" + "gvisor.dev/gvisor/pkg/sighandling" "gvisor.dev/gvisor/pkg/sync" ) @@ -35,6 +39,9 @@ type machine struct { // fd is the vm fd. fd int + // machinePoolIndex is the index in the machinePool array. + machinePoolIndex uint32 + // nextSlot is the next slot for setMemoryRegion. // // This must be accessed atomically. If nextSlot is ^uint32(0), then @@ -192,6 +199,10 @@ func (m *machine) newVCPU() *vCPU { return c // Done. } +// readOnlyGuestRegions contains regions that have to be mapped read-only into +// the guest physical address space. Right now, it is used on arm64 only. +var readOnlyGuestRegions []region + // newMachine returns a new VM context. func newMachine(vm int) (*machine, error) { // Create the machine. @@ -227,6 +238,10 @@ func newMachine(vm int) (*machine, error) { m.upperSharedPageTables.MarkReadOnlyShared() m.kernel.PageTables = pagetables.NewWithUpper(newAllocator(), m.upperSharedPageTables, ring0.KernelStartAddress) + // Install seccomp rules to trap runtime mmap system calls. They will + // be handled by seccompMmapHandler. + seccompMmapRules(m) + // Apply the physical mappings. Note that these mappings may point to // guest physical addresses that are not actually available. These // physical pages are mapped on demand, see kernel_unsafe.go. @@ -241,32 +256,11 @@ func newMachine(vm int) (*machine, error) { return true // Keep iterating. }) - var physicalRegionsReadOnly []physicalRegion - var physicalRegionsAvailable []physicalRegion - - physicalRegionsReadOnly = rdonlyRegionsForSetMem() - physicalRegionsAvailable = availableRegionsForSetMem() - - // Map all read-only regions. - for _, r := range physicalRegionsReadOnly { - m.mapPhysical(r.physical, r.length, physicalRegionsReadOnly, _KVM_MEM_READONLY) - } - // Ensure that the currently mapped virtual regions are actually // available in the VM. Note that this doesn't guarantee no future // faults, however it should guarantee that everything is available to // ensure successful vCPU entry. - applyVirtualRegions(func(vr virtualRegion) { - if excludeVirtualRegion(vr) { - return // skip region. - } - - for _, r := range physicalRegionsReadOnly { - if vr.virtual == r.virtual { - return - } - } - + mapRegion := func(vr region, flags uint32) { for virtual := vr.virtual; virtual < vr.virtual+vr.length; { physical, length, ok := translateToPhysical(virtual) if !ok { @@ -280,9 +274,32 @@ func newMachine(vm int) (*machine, error) { } // Ensure the physical range is mapped. - m.mapPhysical(physical, length, physicalRegionsAvailable, _KVM_MEM_FLAGS_NONE) + m.mapPhysical(physical, length, physicalRegions, flags) virtual += length } + } + + for _, vr := range readOnlyGuestRegions { + mapRegion(vr, _KVM_MEM_READONLY) + } + + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) { + return // skip region. + } + for _, r := range readOnlyGuestRegions { + if vr.virtual == r.virtual { + return + } + } + // Take into account that the stack can grow down. + if vr.filename == "[stack]" { + vr.virtual -= 1 << 20 + vr.length += 1 << 20 + } + + mapRegion(vr.region, 0) + }) // Initialize architecture state. @@ -352,6 +369,10 @@ func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalReg func (m *machine) Destroy() { runtime.SetFinalizer(m, nil) + machinePoolMu.Lock() + machinePool[m.machinePoolIndex].Store(nil) + machinePoolMu.Unlock() + // Destroy vCPUs. for _, c := range m.vCPUsByID { if c == nil { @@ -683,3 +704,72 @@ func (c *vCPU) setSystemTimeLegacy() error { } } } + +const machinePoolSize = 16 + +// machinePool is enumerated from the seccompMmapHandler signal handler +var ( + machinePool [machinePoolSize]machineAtomicPtr + machinePoolLen uint32 + machinePoolMu sync.Mutex + seccompMmapRulesOnce gosync.Once +) + +func sigsysHandler() +func addrOfSigsysHandler() uintptr + +// seccompMmapRules adds seccomp rules to trap mmap system calls that will be +// handled in seccompMmapHandler. +func seccompMmapRules(m *machine) { + seccompMmapRulesOnce.Do(func() { + // Install the handler. + if err := sighandling.ReplaceSignalHandler(unix.SIGSYS, addrOfSigsysHandler(), &savedSigsysHandler); err != nil { + panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err)) + } + rules := []seccomp.RuleSet{} + rules = append(rules, []seccomp.RuleSet{ + // Trap mmap system calls and handle them in sigsysGoHandler + { + Rules: seccomp.SyscallRules{ + unix.SYS_MMAP: { + { + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + /* MAP_DENYWRITE is ignored and used only for filtering. */ + seccomp.MaskedEqual(unix.MAP_DENYWRITE, 0), + }, + }, + }, + Action: linux.SECCOMP_RET_TRAP, + }, + }...) + instrs, err := seccomp.BuildProgram(rules, linux.SECCOMP_RET_ALLOW, linux.SECCOMP_RET_ALLOW) + if err != nil { + panic(fmt.Sprintf("failed to build rules: %v", err)) + } + // Perform the actual installation. + if err := seccomp.SetFilter(instrs); err != nil { + panic(fmt.Sprintf("failed to set filter: %v", err)) + } + }) + + machinePoolMu.Lock() + n := atomic.LoadUint32(&machinePoolLen) + i := uint32(0) + for ; i < n; i++ { + if machinePool[i].Load() == nil { + break + } + } + if i == n { + if i == machinePoolSize { + machinePoolMu.Unlock() + panic("machinePool is full") + } + atomic.AddUint32(&machinePoolLen, 1) + } + machinePool[i].Store(m) + m.machinePoolIndex = i + machinePoolMu.Unlock() +} diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index a96634381..5bc023899 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -29,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" ktime "gvisor.dev/gvisor/pkg/sentry/time" ) @@ -72,10 +71,6 @@ type vCPUArchState struct { // // This starts above fixedKernelPCID. PCIDs *pagetables.PCIDs - - // floatingPointState is the floating point state buffer used in guest - // to host transitions. See usage in bluepill_amd64.go. - floatingPointState fpu.State } const ( @@ -152,12 +147,6 @@ func (c *vCPU) initArchState() error { return fmt.Errorf("error setting user registers: %v", errno) } - // Allocate some floating point state save area for the local vCPU. - // This will be saved prior to leaving the guest, and we restore from - // this always. We cannot use the pointer in the context alone because - // we don't know how large the area there is in reality. - c.floatingPointState = fpu.NewState() - // Set the time offset to the host native time. return c.setSystemTime() } @@ -309,22 +298,6 @@ func loadByte(ptr *byte) byte { return *ptr } -// prefaultFloatingPointState touches each page of the floating point state to -// be sure that its physical pages are mapped. -// -// Otherwise the kernel can trigger KVM_EXIT_MMIO and an instruction that -// triggered a fault will be emulated by the kvm kernel code, but it can't -// emulate instructions like xsave and xrstor. -// -//go:nosplit -func prefaultFloatingPointState(data *fpu.State) { - size := len(*data) - for i := 0; i < size; i += hostarch.PageSize { - loadByte(&(*data)[i]) - } - loadByte(&(*data)[size-1]) -} - // SwitchToUser unpacks architectural-details. func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) (hostarch.AccessType, error) { // Check for canonical addresses. @@ -355,11 +328,6 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) // allocations occur. entersyscall() bluepill(c) - // The root table physical page has to be mapped to not fault in iret - // or sysret after switching into a user address space. sysret and - // iret are in the upper half that is global and already mapped. - switchOpts.PageTables.PrefaultRootTable() - prefaultFloatingPointState(switchOpts.FloatingPointState) vector = c.CPU.SwitchToUser(switchOpts) exitsyscall() @@ -522,3 +490,7 @@ func (m *machine) getNewVCPU() *vCPU { } return nil } + +func archPhysicalRegions(physicalRegions []physicalRegion) []physicalRegion { + return physicalRegions +} diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go index de798bb2c..fbacea9ad 100644 --- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go @@ -161,3 +161,15 @@ func (c *vCPU) getSystemRegisters(sregs *systemRegs) unix.Errno { } return 0 } + +//go:nosplit +func seccompMmapSyscall(context unsafe.Pointer) (uintptr, uintptr, unix.Errno) { + ctx := bluepillArchContext(context) + + // MAP_DENYWRITE is deprecated and ignored by kernel. We use it only for seccomp filters. + addr, _, e := unix.RawSyscall6(uintptr(ctx.Rax), uintptr(ctx.Rdi), uintptr(ctx.Rsi), + uintptr(ctx.Rdx), uintptr(ctx.R10)|unix.MAP_DENYWRITE, uintptr(ctx.R8), uintptr(ctx.R9)) + ctx.Rax = uint64(addr) + + return addr, uintptr(ctx.Rsi), e +} diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index 7937a8481..31998a600 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" ) @@ -40,10 +39,6 @@ type vCPUArchState struct { // // This starts above fixedKernelPCID. PCIDs *pagetables.PCIDs - - // floatingPointState is the floating point state buffer used in guest - // to host transitions. See usage in bluepill_arm64.go. - floatingPointState fpu.State } const ( @@ -110,18 +105,128 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { return phyRegions } +// archPhysicalRegions fills readOnlyGuestRegions and allocates separate +// physical regions form them. +func archPhysicalRegions(physicalRegions []physicalRegion) []physicalRegion { + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) { + return // skip region. + } + if !vr.accessType.Write { + readOnlyGuestRegions = append(readOnlyGuestRegions, vr.region) + } + }) + + rdRegions := readOnlyGuestRegions[:] + + // Add an unreachable region. + rdRegions = append(rdRegions, region{ + virtual: 0xffffffffffffffff, + length: 0, + }) + + var regions []physicalRegion + addValidRegion := func(r *physicalRegion, virtual, length uintptr) { + if length == 0 { + return + } + regions = append(regions, physicalRegion{ + region: region{ + virtual: virtual, + length: length, + }, + physical: r.physical + (virtual - r.virtual), + }) + } + i := 0 + for _, pr := range physicalRegions { + start := pr.virtual + end := pr.virtual + pr.length + for start < end { + rdRegion := rdRegions[i] + rdStart := rdRegion.virtual + rdEnd := rdRegion.virtual + rdRegion.length + if rdEnd <= start { + i++ + continue + } + if rdStart > start { + newEnd := rdStart + if end < rdStart { + newEnd = end + } + addValidRegion(&pr, start, newEnd-start) + start = rdStart + continue + } + if rdEnd < end { + addValidRegion(&pr, start, rdEnd-start) + start = rdEnd + continue + } + addValidRegion(&pr, start, end-start) + start = end + } + } + + return regions +} + // Get all available physicalRegions. -func availableRegionsForSetMem() (phyRegions []physicalRegion) { - var excludeRegions []region +func availableRegionsForSetMem() []physicalRegion { + var excludedRegions []region applyVirtualRegions(func(vr virtualRegion) { if !vr.accessType.Write { - excludeRegions = append(excludeRegions, vr.region) + excludedRegions = append(excludedRegions, vr.region) } }) - phyRegions = computePhysicalRegions(excludeRegions) + // Add an unreachable region. + excludedRegions = append(excludedRegions, region{ + virtual: 0xffffffffffffffff, + length: 0, + }) - return phyRegions + var regions []physicalRegion + addValidRegion := func(r *physicalRegion, virtual, length uintptr) { + if length == 0 { + return + } + regions = append(regions, physicalRegion{ + region: region{ + virtual: virtual, + length: length, + }, + physical: r.physical + (virtual - r.virtual), + }) + } + i := 0 + for _, pr := range physicalRegions { + start := pr.virtual + end := pr.virtual + pr.length + for start < end { + er := excludedRegions[i] + excludeEnd := er.virtual + er.length + excludeStart := er.virtual + if excludeEnd < start { + i++ + continue + } + if excludeStart < start { + start = excludeEnd + i++ + continue + } + rend := excludeStart + if rend > end { + rend = end + } + addValidRegion(&pr, start, rend-start) + start = excludeEnd + } + } + + return regions } // nonCanonical generates a canonical address return. diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 1a4a9ce7d..e73d5c544 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -28,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" ktime "gvisor.dev/gvisor/pkg/sentry/time" ) @@ -159,8 +158,6 @@ func (c *vCPU) initArchState() error { c.PCIDs = pagetables.NewPCIDs(fixedKernelPCID+1, poolPCIDs) } - c.floatingPointState = fpu.NewState() - return c.setSystemTime() } @@ -333,3 +330,15 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) } } + +//go:nosplit +func seccompMmapSyscall(context unsafe.Pointer) (uintptr, uintptr, unix.Errno) { + ctx := bluepillArchContext(context) + + // MAP_DENYWRITE is deprecated and ignored by kernel. We use it only for seccomp filters. + addr, _, e := unix.RawSyscall6(uintptr(ctx.Regs[8]), uintptr(ctx.Regs[0]), uintptr(ctx.Regs[1]), + uintptr(ctx.Regs[2]), uintptr(ctx.Regs[3])|unix.MAP_DENYWRITE, uintptr(ctx.Regs[4]), uintptr(ctx.Regs[5])) + ctx.Regs[0] = uint64(addr) + + return addr, uintptr(ctx.Regs[1]), e +} diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index cc3a1253b..cf3a4e7c9 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -171,3 +171,46 @@ func (c *vCPU) setSignalMask() error { return nil } + +// seccompMmapHandler is a signal handler for runtime mmap system calls +// that are trapped by seccomp. +// +// It executes the mmap syscall with specified arguments and maps a new region +// to the guest. +// +//go:nosplit +func seccompMmapHandler(context unsafe.Pointer) { + addr, length, errno := seccompMmapSyscall(context) + if errno != 0 { + return + } + + for i := uint32(0); i < atomic.LoadUint32(&machinePoolLen); i++ { + m := machinePool[i].Load() + if m == nil { + continue + } + + // Map the new region to the guest. + vr := region{ + virtual: addr, + length: length, + } + for virtual := vr.virtual; virtual < vr.virtual+vr.length; { + physical, length, ok := translateToPhysical(virtual) + if !ok { + // This must be an invalid region that was + // knocked out by creation of the physical map. + return + } + if virtual+length > vr.virtual+vr.length { + // Cap the length to the end of the area. + length = vr.virtual + vr.length - virtual + } + + // Ensure the physical range is mapped. + m.mapPhysical(physical, length, physicalRegions, _KVM_MEM_FLAGS_NONE) + virtual += length + } + } +} diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go index d812e6c26..9864d1258 100644 --- a/pkg/sentry/platform/kvm/physical_map.go +++ b/pkg/sentry/platform/kvm/physical_map.go @@ -168,6 +168,9 @@ func computePhysicalRegions(excludedRegions []region) (physicalRegions []physica } addValidRegion(lastExcludedEnd, ring0.MaximumUserAddress-lastExcludedEnd) + // Do arch-specific actions on physical regions. + physicalRegions = archPhysicalRegions(physicalRegions) + // Dump our all physical regions. for _, r := range physicalRegions { log.Infof("physicalRegion: virtual [%x,%x) => physical [%x,%x)", diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go index 6d0ba8252..346a10043 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go @@ -30,8 +30,8 @@ import ( func TLSWorks() bool // SetTestTarget sets the rip appropriately. -func SetTestTarget(regs *arch.Registers, fn func()) { - regs.Pc = uint64(reflect.ValueOf(fn).Pointer()) +func SetTestTarget(regs *arch.Registers, fn uintptr) { + regs.Pc = uint64(fn) } // SetTouchTarget sets rax appropriately. diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s index 7348c29a5..42876245a 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s @@ -28,6 +28,11 @@ TEXT ·Getpid(SB),NOSPLIT,$0 SVC RET +TEXT ·AddrOfGetpid(SB),NOSPLIT,$0-8 + MOVD $·Getpid(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·Touch(SB),NOSPLIT,$0 start: MOVD 0(R8), R1 @@ -35,21 +40,41 @@ start: SVC B start +TEXT ·AddrOfTouch(SB),NOSPLIT,$0-8 + MOVD $·Touch(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·HaltLoop(SB),NOSPLIT,$0 start: HLT B start +TEXT ·AddOfHaltLoop(SB),NOSPLIT,$0-8 + MOVD $·HaltLoop(SB), R0 + MOVD R0, ret+0(FP) + RET + // This function simulates a loop of syscall. TEXT ·SyscallLoop(SB),NOSPLIT,$0 start: SVC B start +TEXT ·AddrOfSyscallLoop(SB),NOSPLIT,$0-8 + MOVD $·SyscallLoop(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·SpinLoop(SB),NOSPLIT,$0 start: B start +TEXT ·AddrOfSpinLoop(SB),NOSPLIT,$0-8 + MOVD $·SpinLoop(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·TLSWorks(SB),NOSPLIT,$0-8 NO_LOCAL_POINTERS MOVD $0x6789, R5 @@ -125,6 +150,11 @@ TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0 SVC RET // never reached +TEXT ·AddrOfTwiddleRegsSyscall(SB),NOSPLIT,$0-8 + MOVD $·TwiddleRegsSyscall(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0 TWIDDLE_REGS() MSR R10, TPIDR_EL0 @@ -132,3 +162,8 @@ TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0 // Branch to Register branches unconditionally to an address in <Rn>. JMP (R6) // <=> br x6, must fault RET // never reached + +TEXT ·AddrOfTwiddleRegsFault(SB),NOSPLIT,$0-8 + MOVD $·TwiddleRegsFault(SB), R0 + MOVD R0, ret+0(FP) + RET diff --git a/pkg/sentry/seccheck/BUILD b/pkg/sentry/seccheck/BUILD index 943fa180d..35feb969f 100644 --- a/pkg/sentry/seccheck/BUILD +++ b/pkg/sentry/seccheck/BUILD @@ -8,6 +8,8 @@ go_fieldenum( name = "seccheck_fieldenum", srcs = [ "clone.go", + "execve.go", + "exit.go", "task.go", ], out = "seccheck_fieldenum.go", @@ -29,6 +31,8 @@ go_library( name = "seccheck", srcs = [ "clone.go", + "execve.go", + "exit.go", "seccheck.go", "seccheck_fieldenum.go", "seqatomic_checkerslice_unsafe.go", diff --git a/pkg/sentry/seccheck/execve.go b/pkg/sentry/seccheck/execve.go new file mode 100644 index 000000000..f36e0730e --- /dev/null +++ b/pkg/sentry/seccheck/execve.go @@ -0,0 +1,65 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seccheck + +import ( + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" +) + +// ExecveInfo contains information used by the Execve checkpoint. +// +// +fieldenum Execve +type ExecveInfo struct { + // Invoker identifies the invoking thread. + Invoker TaskInfo + + // Credentials are the invoking thread's credentials. + Credentials *auth.Credentials + + // BinaryPath is a path to the executable binary file being switched to in + // the mount namespace in which it was opened. + BinaryPath string + + // Argv is the new process image's argument vector. + Argv []string + + // Env is the new process image's environment variables. + Env []string + + // BinaryMode is the executable binary file's mode. + BinaryMode uint16 + + // BinarySHA256 is the SHA-256 hash of the executable binary file. + // + // Note that this requires reading the entire file into memory, which is + // likely to be extremely slow. + BinarySHA256 [32]byte +} + +// ExecveReq returns fields required by the Execve checkpoint. +func (s *state) ExecveReq() ExecveFieldSet { + return s.execveReq.Load() +} + +// Execve is called at the Execve checkpoint. +func (s *state) Execve(ctx context.Context, mask ExecveFieldSet, info *ExecveInfo) error { + for _, c := range s.getCheckers() { + if err := c.Execve(ctx, mask, *info); err != nil { + return err + } + } + return nil +} diff --git a/pkg/sentry/seccheck/exit.go b/pkg/sentry/seccheck/exit.go new file mode 100644 index 000000000..69cb6911c --- /dev/null +++ b/pkg/sentry/seccheck/exit.go @@ -0,0 +1,57 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seccheck + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" +) + +// ExitNotifyParentInfo contains information used by the ExitNotifyParent +// checkpoint. +// +// +fieldenum ExitNotifyParent +type ExitNotifyParentInfo struct { + // Exiter identifies the exiting thread. Note that by the checkpoint's + // definition, Exiter.ThreadID == Exiter.ThreadGroupID and + // Exiter.ThreadStartTime == Exiter.ThreadGroupStartTime, so requesting + // ThreadGroup* fields is redundant. + Exiter TaskInfo + + // ExitStatus is the exiting thread group's exit status, as reported + // by wait*(). + ExitStatus linux.WaitStatus +} + +// ExitNotifyParentReq returns fields required by the ExitNotifyParent +// checkpoint. +func (s *state) ExitNotifyParentReq() ExitNotifyParentFieldSet { + return s.exitNotifyParentReq.Load() +} + +// ExitNotifyParent is called at the ExitNotifyParent checkpoint. +// +// The ExitNotifyParent checkpoint occurs when a zombied thread group leader, +// not waiting for exit acknowledgement from a non-parent ptracer, becomes the +// last non-dead thread in its thread group and notifies its parent of its +// exiting. +func (s *state) ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info *ExitNotifyParentInfo) error { + for _, c := range s.getCheckers() { + if err := c.ExitNotifyParent(ctx, mask, *info); err != nil { + return err + } + } + return nil +} diff --git a/pkg/sentry/seccheck/seccheck.go b/pkg/sentry/seccheck/seccheck.go index b6c9d44ce..e13274096 100644 --- a/pkg/sentry/seccheck/seccheck.go +++ b/pkg/sentry/seccheck/seccheck.go @@ -29,6 +29,8 @@ type Point uint // PointX represents the checkpoint X. const ( PointClone Point = iota + PointExecve + PointExitNotifyParent // Add new Points above this line. pointLength @@ -47,6 +49,8 @@ const ( // registered concurrently with invocations of checkpoints). type Checker interface { Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error + Execve(ctx context.Context, mask ExecveFieldSet, info ExecveInfo) error + ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info ExitNotifyParentInfo) error } // CheckerDefaults may be embedded by implementations of Checker to obtain @@ -58,6 +62,16 @@ func (CheckerDefaults) Clone(ctx context.Context, mask CloneFieldSet, info Clone return nil } +// Execve implements Checker.Execve. +func (CheckerDefaults) Execve(ctx context.Context, mask ExecveFieldSet, info ExecveInfo) error { + return nil +} + +// ExitNotifyParent implements Checker.ExitNotifyParent. +func (CheckerDefaults) ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info ExitNotifyParentInfo) error { + return nil +} + // CheckerReq indicates what checkpoints a corresponding Checker runs at, and // what information it requires at those checkpoints. type CheckerReq struct { @@ -69,7 +83,9 @@ type CheckerReq struct { // All of the following fields indicate what fields in the corresponding // XInfo struct will be requested at the corresponding checkpoint. - Clone CloneFields + Clone CloneFields + Execve ExecveFields + ExitNotifyParent ExitNotifyParentFields } // Global is the method receiver of all seccheck functions. @@ -101,7 +117,9 @@ type state struct { // corresponding XInfo struct have been requested by any registered // checker, are accessed using atomic memory operations, and are mutated // with registrationMu locked. - cloneReq CloneFieldSet + cloneReq CloneFieldSet + execveReq ExecveFieldSet + exitNotifyParentReq ExitNotifyParentFieldSet } // AppendChecker registers the given Checker to execute at checkpoints. The @@ -110,7 +128,11 @@ type state struct { func (s *state) AppendChecker(c Checker, req *CheckerReq) { s.registrationMu.Lock() defer s.registrationMu.Unlock() + s.cloneReq.AddFieldsLoadable(req.Clone) + s.execveReq.AddFieldsLoadable(req.Execve) + s.exitNotifyParentReq.AddFieldsLoadable(req.ExitNotifyParent) + s.appendCheckerLocked(c) for _, p := range req.Points { word, bit := p/32, p%32 diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index 7ee89a735..00f925166 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -4,7 +4,10 @@ package(licenses = ["notice"]) go_library( name = "socket", - srcs = ["socket.go"], + srcs = [ + "socket.go", + "socket_state.go", + ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 00a5e729a..6077b2150 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -29,10 +29,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "time" ) -const maxInt = int(^uint(0) >> 1) - // SCMCredentials represents a SCM_CREDENTIALS socket control message. type SCMCredentials interface { transport.CredentialsControlMessage @@ -78,7 +77,7 @@ func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) { } // Files implements SCMRights.Files. -func (fs *RightsFiles) Files(ctx context.Context, max int) (RightsFiles, bool) { +func (fs *RightsFiles) Files(_ context.Context, max int) (RightsFiles, bool) { n := max var trunc bool if l := len(*fs); n > l { @@ -124,7 +123,7 @@ func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32 break } - fds = append(fds, int32(fd)) + fds = append(fds, fd) } return fds, trunc } @@ -300,8 +299,8 @@ func alignSlice(buf []byte, align uint) []byte { } // PackTimestamp packs a SO_TIMESTAMP socket control message. -func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte { - timestampP := linux.NsecToTimeval(timestamp) +func PackTimestamp(t *kernel.Task, timestamp time.Time, buf []byte) []byte { + timestampP := linux.NsecToTimeval(timestamp.UnixNano()) return putCmsgStruct( buf, linux.SOL_SOCKET, @@ -355,6 +354,17 @@ func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketIn ) } +// PackIPv6PacketInfo packs an IPV6_PKTINFO socket control message. +func PackIPv6PacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPv6PacketInfo, buf []byte) []byte { + return putCmsgStruct( + buf, + linux.SOL_IPV6, + linux.IPV6_PKTINFO, + t.Arch().Width(), + packetInfo, + ) +} + // PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message. func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte { var level uint32 @@ -412,6 +422,10 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf) } + if cmsgs.IP.HasIPv6PacketInfo { + buf = PackIPv6PacketInfo(t, &cmsgs.IP.IPv6PacketInfo, buf) + } + if cmsgs.IP.OriginalDstAddress != nil { buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf) } @@ -453,6 +467,10 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo) } + if cmsgs.IP.HasIPv6PacketInfo { + space += cmsgSpace(t, linux.SizeOfControlMessageIPv6PacketInfo) + } + if cmsgs.IP.OriginalDstAddress != nil { space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes()) } @@ -526,7 +544,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } var ts linux.Timeval ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) - cmsgs.IP.Timestamp = ts.ToNsecCapped() + cmsgs.IP.Timestamp = ts.ToTime() cmsgs.IP.HasTimestamp = true i += bits.AlignUp(length, width) diff --git a/pkg/sentry/socket/control/control_test.go b/pkg/sentry/socket/control/control_test.go index 7e28a0cef..1b04e1bbc 100644 --- a/pkg/sentry/socket/control/control_test.go +++ b/pkg/sentry/socket/control/control_test.go @@ -50,7 +50,7 @@ func TestParse(t *testing.T) { want := socket.ControlMessages{ IP: socket.IPControlMessages{ HasTimestamp: true, - Timestamp: ts.ToNsecCapped(), + Timestamp: ts.ToTime(), }, } if diff := cmp.Diff(want, cmsg); diff != "" { diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 1c1e501ba..6e2318f75 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -111,7 +111,7 @@ func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS } return readv(s.fd, safemem.IovecsFromBlockSeq(dsts)) })) - return int64(n), err + return n, err } // Write implements fs.FileOperations.Write. @@ -134,7 +134,7 @@ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO } return writev(s.fd, safemem.IovecsFromBlockSeq(srcs)) })) - return int64(n), err + return n, err } // Socket implements socket.Provider.Socket. @@ -180,7 +180,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto } // Pair implements socket.Provider.Pair. -func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func (p *socketProvider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) { // Not supported by AF_INET/AF_INET6. return nil, nil, nil } @@ -207,7 +207,7 @@ type socketOpsCommon struct { // Release implements fs.FileOperations.Release. func (s *socketOpsCommon) Release(context.Context) { fdnotifier.RemoveFD(int32(s.fd)) - unix.Close(s.fd) + _ = unix.Close(s.fd) } // Readiness implements waiter.Waitable.Readiness. @@ -218,13 +218,13 @@ func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { // EventRegister implements waiter.Waitable.EventRegister. func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) { s.queue.EventRegister(e, mask) - fdnotifier.UpdateFD(int32(s.fd)) + _ = fdnotifier.UpdateFD(int32(s.fd)) } // EventUnregister implements waiter.Waitable.EventUnregister. func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { s.queue.EventUnregister(e) - fdnotifier.UpdateFD(int32(s.fd)) + _ = fdnotifier.UpdateFD(int32(s.fd)) } // Connect implements socket.Socket.Connect. @@ -316,7 +316,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, if kernel.VFS2Enabled { f, err := newVFS2Socket(t, s.family, s.stype, s.protocol, fd, uint32(flags&unix.SOCK_NONBLOCK)) if err != nil { - unix.Close(fd) + _ = unix.Close(fd) return 0, nil, 0, err } defer f.DecRef(t) @@ -328,7 +328,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, } else { f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&unix.SOCK_NONBLOCK != 0) if err != nil { - unix.Close(fd) + _ = unix.Close(fd) return 0, nil, 0, err } defer f.DecRef(t) @@ -343,7 +343,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, } // Bind implements socket.Socket.Bind. -func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { +func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) > sizeofSockaddr { sockaddr = sockaddr[:sizeofSockaddr] } @@ -356,12 +356,12 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } // Listen implements socket.Socket.Listen. -func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(_ *kernel.Task, backlog int) *syserr.Error { return syserr.FromError(unix.Listen(s.fd, backlog)) } // Shutdown implements socket.Socket.Shutdown. -func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(_ *kernel.Task, how int) *syserr.Error { switch how { case unix.SHUT_RD, unix.SHUT_WR, unix.SHUT_RDWR: return syserr.FromError(unix.Shutdown(s.fd, how)) @@ -371,7 +371,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, _ hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { if outLen < 0 { return nil, syserr.ErrInvalidArgument } @@ -401,7 +401,7 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr case linux.TCP_NODELAY: optlen = sizeofInt32 case linux.TCP_INFO: - optlen = int(linux.SizeOfTCPInfo) + optlen = linux.SizeOfTCPInfo } } @@ -579,7 +579,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s controlMessages.IP.HasTimestamp = true ts := linux.Timeval{} ts.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfTimeval]) - controlMessages.IP.Timestamp = ts.ToNsecCapped() + controlMessages.IP.Timestamp = ts.ToTime() } case linux.SOL_IP: diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index e3eade180..8d9e73243 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -58,8 +58,8 @@ var nameToID = map[string]stack.TableID{ // DefaultLinuxTables returns the rules of stack.DefaultTables() wrapped for // compatibility with netfilter extensions. -func DefaultLinuxTables(seed uint32) *stack.IPTables { - tables := stack.DefaultTables(seed) +func DefaultLinuxTables(seed uint32, clock tcpip.Clock) *stack.IPTables { + tables := stack.DefaultTables(seed, clock) tables.VisitTargets(func(oldTarget stack.Target) stack.Target { switch val := oldTarget.(type) { case *stack.AcceptTarget: diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index ea56f39c1..b9c15daab 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -647,7 +647,7 @@ func (jt *JumpTarget) id() targetID { } // Action implements stack.Target.Action. -func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { +func (jt *JumpTarget) Action(*stack.PacketBuffer, stack.Hook, *stack.Route, stack.AddressableEndpoint) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index bf5ec4558..075f61cda 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "device.go", "netstack.go", + "netstack_state.go", "netstack_vfs2.go", "provider.go", "provider_vfs2.go", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index f79bda922..030c6c8e4 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -274,6 +274,7 @@ var Metrics = tcpip.Stats{ ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."), FailedPortReservations: mustCreateMetric("/netstack/tcp/failed_port_reservations", "Number of time TCP failed to reserve a port."), SegmentsAckedWithDSACK: mustCreateMetric("/netstack/tcp/segments_acked_with_dsack", "Number of segments for which DSACK was received."), + SpuriousRecovery: mustCreateMetric("/netstack/tcp/spurious_recovery", "Number of times the connection entered loss recovery spuriously."), }, UDP: tcpip.UDPStats{ PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."), @@ -378,9 +379,9 @@ type socketOpsCommon struct { // timestampValid indicates whether timestamp for SIOCGSTAMP has been // set. It is protected by readMu. timestampValid bool - // timestampNS holds the timestamp to use with SIOCTSTAMP. It is only + // timestamp holds the timestamp to use with SIOCTSTAMP. It is only // valid when timestampValid is true. It is protected by readMu. - timestampNS int64 + timestamp time.Time `state:".(int64)"` // TODO(b/153685824): Move this to SocketOptions. // sockOptInq corresponds to TCP_INQ. @@ -410,15 +411,6 @@ var sockAddrInetSize = (*linux.SockAddrInet)(nil).SizeBytes() var sockAddrInet6Size = (*linux.SockAddrInet6)(nil).SizeBytes() var sockAddrLinkSize = (*linux.SockAddrLink)(nil).SizeBytes() -// bytesToIPAddress converts an IPv4 or IPv6 address from the user to the -// netstack representation taking any addresses into account. -func bytesToIPAddress(addr []byte) tcpip.Address { - if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) { - return "" - } - return tcpip.Address(addr) -} - // minSockAddrLen returns the minimum length in bytes of a socket address for // the socket's family. func (s *socketOpsCommon) minSockAddrLen() int { @@ -468,7 +460,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) { t := kernel.TaskFromContext(ctx) start := t.Kernel().MonotonicClock().Now() deadline := start.Add(v.Timeout) - t.BlockWithDeadline(ch, true, deadline) + _ = t.BlockWithDeadline(ch, true, deadline) } } @@ -488,7 +480,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS } // WriteTo implements fs.FileOperations.WriteTo. -func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) { +func (s *SocketOperations) WriteTo(_ context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) { s.readMu.Lock() defer s.readMu.Unlock() @@ -543,7 +535,7 @@ func (l *limitedPayloader) Len() int { } // ReadFrom implements fs.FileOperations.ReadFrom. -func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { +func (s *SocketOperations) ReadFrom(_ context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { f := limitedPayloader{ inner: io.LimitedReader{ R: r, @@ -654,7 +646,7 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. -func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { +func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) < 2 { return syserr.ErrInvalidArgument } @@ -672,13 +664,10 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } a.UnmarshalBytes(sockaddr[:sockAddrLinkSize]) - if a.Protocol != uint16(s.protocol) { - return syserr.ErrInvalidArgument - } - addr = tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + Port: socket.Ntohs(a.Protocol), } } else { if s.minSockAddrLen() > len(sockaddr) { @@ -717,7 +706,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { // Listen implements the linux syscall listen(2) for sockets backed by // tcpip.Endpoint. -func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(_ *kernel.Task, backlog int) *syserr.Error { return syserr.TranslateNetstackError(s.Endpoint.Listen(backlog)) } @@ -808,7 +797,7 @@ func ConvertShutdown(how int) (tcpip.ShutdownFlags, *syserr.Error) { // Shutdown implements the linux syscall shutdown(2) for sockets backed by // tcpip.Endpoint. -func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(_ *kernel.Task, how int) *syserr.Error { f, err := ConvertShutdown(how) if err != nil { return err @@ -889,7 +878,7 @@ func boolToInt32(v bool) int32 { } // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, _ linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_ERROR: @@ -1374,6 +1363,14 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) return &v, nil + case linux.IPV6_RECVPKTINFO: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetIPv6ReceivePacketInfo())) + return &v, nil + case linux.IP6T_ORIGINAL_DST: if outLen < sockAddrInet6Size { return nil, syserr.ErrInvalidArgument @@ -1397,11 +1394,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return nil, syserr.ErrNoDevice } - info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, true) + info, err := netfilter.GetInfo(t, stk.(*Stack).Stack, outPtr, true) if err != nil { return nil, err } @@ -1417,11 +1414,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return nil, syserr.ErrNoDevice } - entries, err := netfilter.GetEntries6(t, stack.(*Stack).Stack, outPtr, outLen) + entries, err := netfilter.GetEntries6(t, stk.(*Stack).Stack, outPtr, outLen) if err != nil { return nil, err } @@ -1437,8 +1434,8 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return nil, syserr.ErrNoDevice } ret, err := netfilter.TargetRevision(t, outPtr, header.IPv6ProtocolNumber) @@ -1454,7 +1451,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) { +func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int, _ int) (marshal.Marshallable, *syserr.Error) { if _, ok := ep.(tcpip.Endpoint); !ok { log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) return nil, syserr.ErrUnknownProtocolOption @@ -1594,11 +1591,11 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return nil, syserr.ErrNoDevice } - info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, false) + info, err := netfilter.GetInfo(t, stk.(*Stack).Stack, outPtr, false) if err != nil { return nil, err } @@ -1614,11 +1611,11 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return nil, syserr.ErrNoDevice } - entries, err := netfilter.GetEntries4(t, stack.(*Stack).Stack, outPtr, outLen) + entries, err := netfilter.GetEntries4(t, stk.(*Stack).Stack, outPtr, outLen) if err != nil { return nil, err } @@ -1634,8 +1631,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return nil, syserr.ErrNoDevice } ret, err := netfilter.TargetRevision(t, outPtr, header.IPv4ProtocolNumber) @@ -2130,6 +2127,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) return nil + case linux.IPV6_RECVPKTINFO: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(hostarch.ByteOrder.Uint32(optVal)) + + ep.SocketOptions().SetIPv6ReceivePacketInfo(v != 0) + return nil + case linux.IPV6_TCLASS: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -2172,12 +2178,12 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return syserr.ErrNoDevice } // Stack must be a netstack stack. - return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, true) + return netfilter.SetEntries(t, stk.(*Stack).Stack, optVal, true) case linux.IP6T_SO_SET_ADD_COUNTERS: log.Infof("IP6T_SO_SET_ADD_COUNTERS is not supported") @@ -2415,12 +2421,12 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return syserr.ErrNoDevice } // Stack must be a netstack stack. - return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, false) + return netfilter.SetEntries(t, stk.(*Stack).Stack, optVal, false) case linux.IPT_SO_SET_ADD_COUNTERS: log.Infof("IPT_SO_SET_ADD_COUNTERS is not supported") @@ -2519,7 +2525,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_RECVHOPLIMIT, linux.IPV6_RECVHOPOPTS, linux.IPV6_RECVPATHMTU, - linux.IPV6_RECVPKTINFO, linux.IPV6_RECVRTHDR, linux.IPV6_RTHDR, linux.IPV6_RTHDRDSTOPTS, @@ -2588,7 +2593,7 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { // GetSockName implements the linux syscall getsockname(2) for sockets backed by // tcpip.Endpoint. -func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetSockName(*kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -2600,7 +2605,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // tcpip.Endpoint. -func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetPeerName(*kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -2745,6 +2750,8 @@ func (s *socketOpsCommon) controlMessages(cm tcpip.ControlMessages) socket.Contr TClass: readCM.TClass, HasIPPacketInfo: readCM.HasIPPacketInfo, PacketInfo: readCM.PacketInfo, + HasIPv6PacketInfo: readCM.HasIPv6PacketInfo, + IPv6PacketInfo: readCM.IPv6PacketInfo, OriginalDstAddress: readCM.OriginalDstAddress, SockErr: readCM.SockErr, }, @@ -2759,7 +2766,7 @@ func (s *socketOpsCommon) updateTimestamp(cm tcpip.ControlMessages) { // Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled. if !s.sockOptTimestamp { s.timestampValid = true - s.timestampNS = cm.Timestamp + s.timestamp = cm.Timestamp } } @@ -2818,7 +2825,7 @@ func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int, // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // tcpip.Endpoint. -func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, _ uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { if flags&linux.MSG_ERRQUEUE != 0 { return s.recvErr(t, dst) } @@ -2983,7 +2990,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy return 0, linuxerr.ENOENT } - tv := linux.NsecToTimeval(s.timestampNS) + tv := linux.NsecToTimeval(s.timestamp.UnixNano()) _, err := tv.CopyOut(t, args[2].Pointer()) return 0, err @@ -3090,7 +3097,7 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc } // interfaceIoctl implements interface requests. -func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFReq) *syserr.Error { +func interfaceIoctl(ctx context.Context, _ usermem.IO, arg int, ifr *linux.IFReq) *syserr.Error { var ( iface inet.Interface index int32 @@ -3098,8 +3105,8 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe ) // Find the relevant device. - stack := inet.StackFromContext(ctx) - if stack == nil { + stk := inet.StackFromContext(ctx) + if stk == nil { return syserr.ErrNoDevice } @@ -3109,7 +3116,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // Gets the name of the interface given the interface index // stored in ifr_ifindex. index = int32(hostarch.ByteOrder.Uint32(ifr.Data[:4])) - if iface, ok := stack.Interfaces()[index]; ok { + if iface, ok := stk.Interfaces()[index]; ok { ifr.SetName(iface.Name) return nil } @@ -3117,7 +3124,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } // Find the relevant device. - for index, iface = range stack.Interfaces() { + for index, iface = range stk.Interfaces() { if iface.Name == ifr.Name() { found = true break @@ -3150,7 +3157,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } case linux.SIOCGIFFLAGS: - f, err := interfaceStatusFlags(stack, iface.Name) + f, err := interfaceStatusFlags(stk, iface.Name) if err != nil { return err } @@ -3160,7 +3167,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe case linux.SIOCGIFADDR: // Copy the IPv4 address out. - for _, addr := range stack.InterfaceAddrs()[index] { + for _, addr := range stk.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. if addr.Family != linux.AF_INET { continue @@ -3196,7 +3203,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe case linux.SIOCGIFNETMASK: // Gets the network mask of a device. - for _, addr := range stack.InterfaceAddrs()[index] { + for _, addr := range stk.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. if addr.Family != linux.AF_INET { continue @@ -3228,24 +3235,24 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } // ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl. -func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.IFConf) error { +func ifconfIoctl(ctx context.Context, t *kernel.Task, _ usermem.IO, ifc *linux.IFConf) error { // If Ptr is NULL, return the necessary buffer size via Len. // Otherwise, write up to Len bytes starting at Ptr containing ifreq // structs. - stack := inet.StackFromContext(ctx) - if stack == nil { + stk := inet.StackFromContext(ctx) + if stk == nil { return syserr.ErrNoDevice.ToError() } if ifc.Ptr == 0 { - ifc.Len = int32(len(stack.Interfaces())) * int32(linux.SizeOfIFReq) + ifc.Len = int32(len(stk.Interfaces())) * int32(linux.SizeOfIFReq) return nil } max := ifc.Len ifc.Len = 0 - for key, ifaceAddrs := range stack.InterfaceAddrs() { - iface := stack.Interfaces()[key] + for key, ifaceAddrs := range stk.InterfaceAddrs() { + iface := stk.Interfaces()[key] for _, ifaceAddr := range ifaceAddrs { // Don't write past the end of the buffer. if ifc.Len+int32(linux.SizeOfIFReq) > max { diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/sentry/socket/netstack/netstack_state.go index 529e02a07..591e00d42 100644 --- a/pkg/tcpip/stack/iptables_state.go +++ b/pkg/sentry/socket/netstack/netstack_state.go @@ -1,4 +1,4 @@ -// Copyright 2020 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,29 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -package stack +package netstack import ( "time" ) -// +stateify savable -type unixTime struct { - second int64 - nano int64 +func (s *socketOpsCommon) saveTimestamp() int64 { + s.readMu.Lock() + defer s.readMu.Unlock() + return s.timestamp.UnixNano() } -// saveLastUsed is invoked by stateify. -func (cn *conn) saveLastUsed() unixTime { - return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()} -} - -// loadLastUsed is invoked by stateify. -func (cn *conn) loadLastUsed(unix unixTime) { - cn.lastUsed = time.Unix(unix.second, unix.nano) -} - -// beforeSave is invoked by stateify. -func (ct *ConnTrack) beforeSave() { - ct.mu.Lock() +func (s *socketOpsCommon) loadTimestamp(nsec int64) { + s.readMu.Lock() + defer s.readMu.Unlock() + s.timestamp = time.Unix(0, nsec) } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 208ab9909..ea199f223 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -155,7 +155,7 @@ func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { // Attach address to interface. nicID := tcpip.NICID(idx) - if err := s.Stack.AddProtocolAddressWithOptions(nicID, protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + if err := s.Stack.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil { return syserr.TranslateNetstackError(err).ToError() } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 841d5bd55..d4b80a39d 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -21,6 +21,7 @@ import ( "bytes" "fmt" "sync/atomic" + "time" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" @@ -51,8 +52,19 @@ type ControlMessages struct { func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo { var p linux.ControlMessageIPPacketInfo p.NIC = int32(packetInfo.NIC) - copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) - copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) + copy(p.LocalAddr[:], packetInfo.LocalAddr) + copy(p.DestinationAddr[:], packetInfo.DestinationAddr) + return p +} + +// ipv6PacketInfoToLinux converts IPv6PacketInfo from tcpip format to Linux +// format. +func ipv6PacketInfoToLinux(packetInfo tcpip.IPv6PacketInfo) linux.ControlMessageIPv6PacketInfo { + var p linux.ControlMessageIPv6PacketInfo + if n := copy(p.Addr[:], packetInfo.Addr); n != len(p.Addr) { + panic(fmt.Sprintf("got copy(%x, %x) = %d, want = %d", p.Addr, packetInfo.Addr, n, len(p.Addr))) + } + p.NIC = uint32(packetInfo.NIC) return p } @@ -114,7 +126,7 @@ func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessa if cmgs.HasOriginalDstAddress { orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress) } - return IPControlMessages{ + cm := IPControlMessages{ HasTimestamp: cmgs.HasTimestamp, Timestamp: cmgs.Timestamp, HasInq: cmgs.HasInq, @@ -125,9 +137,16 @@ func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessa TClass: cmgs.TClass, HasIPPacketInfo: cmgs.HasIPPacketInfo, PacketInfo: packetInfoToLinux(cmgs.PacketInfo), + HasIPv6PacketInfo: cmgs.HasIPv6PacketInfo, OriginalDstAddress: orgDstAddr, SockErr: sockErrCmsgToLinux(cmgs.SockErr), } + + if cm.HasIPv6PacketInfo { + cm.IPv6PacketInfo = ipv6PacketInfoToLinux(cmgs.IPv6PacketInfo) + } + + return cm } // IPControlMessages contains socket control messages for IP sockets. @@ -138,9 +157,9 @@ type IPControlMessages struct { // HasTimestamp indicates whether Timestamp is valid/set. HasTimestamp bool - // Timestamp is the time (in ns) that the last packet used to create - // the read data was received. - Timestamp int64 + // Timestamp is the time that the last packet used to create the read data + // was received. + Timestamp time.Time `state:".(int64)"` // HasInq indicates whether Inq is valid/set. HasInq bool @@ -166,6 +185,12 @@ type IPControlMessages struct { // PacketInfo holds interface and address data on an incoming packet. PacketInfo linux.ControlMessageIPPacketInfo + // HasIPv6PacketInfo indicates whether IPv6PacketInfo is set. + HasIPv6PacketInfo bool + + // PacketInfo holds interface and address data on an incoming packet. + IPv6PacketInfo linux.ControlMessageIPv6PacketInfo + // OriginalDestinationAddress holds the original destination address // and port of the incoming packet. OriginalDstAddress linux.SockAddr diff --git a/pkg/sentry/socket/socket_state.go b/pkg/sentry/socket/socket_state.go new file mode 100644 index 000000000..32e12b238 --- /dev/null +++ b/pkg/sentry/socket/socket_state.go @@ -0,0 +1,27 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package socket + +import ( + "time" +) + +func (i *IPControlMessages) saveTimestamp() int64 { + return i.Timestamp.UnixNano() +} + +func (i *IPControlMessages) loadTimestamp(nsec int64) { + i.Timestamp = time.Unix(0, nsec) +} diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index a9cedcf5f..188ad3bd9 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -59,12 +59,14 @@ func (q *queue) Close() { // q.WriterQueue.Notify(waiter.WritableEvents) func (q *queue) Reset(ctx context.Context) { q.mu.Lock() - for cur := q.dataList.Front(); cur != nil; cur = cur.Next() { - cur.Release(ctx) - } + dataList := q.dataList q.dataList.Reset() q.used = 0 q.mu.Unlock() + + for cur := dataList.Front(); cur != nil; cur = cur.Next() { + cur.Release(ctx) + } } // DecRef implements RefCounter.DecRef. diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index 757ff2a40..4d3f4d556 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -610,9 +610,9 @@ func (i *SyscallInfo) printExit(t *kernel.Task, elapsed time.Duration, output [] if err == nil { // Fill in the output after successful execution. i.post(t, args, retval, output, LogMaximumSize) - rval = fmt.Sprintf("%#x (%v)", retval, elapsed) + rval = fmt.Sprintf("%d (%#x) (%v)", retval, retval, elapsed) } else { - rval = fmt.Sprintf("%#x errno=%d (%s) (%v)", retval, errno, err, elapsed) + rval = fmt.Sprintf("%d (%#x) errno=%d (%s) (%v)", retval, retval, errno, err, elapsed) } switch len(output) { diff --git a/pkg/sentry/time/sampler_arm64.go b/pkg/sentry/time/sampler_arm64.go index 3560e66ae..9b8c9a480 100644 --- a/pkg/sentry/time/sampler_arm64.go +++ b/pkg/sentry/time/sampler_arm64.go @@ -30,9 +30,9 @@ func getDefaultArchOverheadCycles() TSCValue { // frqRatio. defaultOverheadCycles of ARM equals to that on // x86 devided by frqRatio cntfrq := getCNTFRQ() - frqRatio := 1000000000 / cntfrq + frqRatio := 1000000000 / float64(cntfrq) overheadCycles := (1 * 1000) / frqRatio - return overheadCycles + return TSCValue(overheadCycles) } // defaultOverheadTSC is the default estimated syscall overhead in TSC cycles. diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go index e7073ec87..d9df890c4 100644 --- a/pkg/sentry/usage/memory.go +++ b/pkg/sentry/usage/memory.go @@ -252,9 +252,9 @@ func (m *MemoryLocked) Copy() (MemoryStats, uint64) { return ms, m.totalLocked() } -// These options control how much total memory the is reported to the application. -// They may only be set before the application starts executing, and must not -// be modified. +// These options control how much total memory the is reported to the +// application. They may only be set before the application starts executing, +// and must not be modified. var ( // MinimumTotalMemoryBytes is the minimum reported total system memory. MinimumTotalMemoryBytes uint64 = 2 << 30 // 2 GB diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index 04bc4d10c..fefd0fc9c 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -135,12 +135,16 @@ func (ep *EpollInstance) Readiness(mask waiter.EventMask) waiter.EventMask { return 0 } ep.mu.Lock() - for epi := ep.ready.Front(); epi != nil; epi = epi.Next() { + var next *epollInterest + for epi := ep.ready.Front(); epi != nil; epi = next { + next = epi.Next() wmask := waiter.EventMaskFromLinux(epi.mask) if epi.key.file.Readiness(wmask)&wmask != 0 { ep.mu.Unlock() return waiter.ReadableEvents } + ep.ready.Remove(epi) + epi.ready = false } ep.mu.Unlock() return 0 diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 5dab069ed..452f5f1f9 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -17,6 +17,7 @@ package vfs import ( "bytes" "io" + "math" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -399,6 +400,9 @@ func (fd *DynamicBytesFileDescriptionImpl) Write(ctx context.Context, src userme // GenericConfigureMMap may be used by most implementations of // FileDescriptionImpl.ConfigureMMap. func GenericConfigureMMap(fd *FileDescription, m memmap.Mappable, opts *memmap.MMapOpts) error { + if opts.Offset+opts.Length > math.MaxInt64 { + return linuxerr.EOVERFLOW + } opts.Mappable = m opts.MappingIdentity = fd fd.IncRef() diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index 7fd7f000d..40aff2927 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -223,6 +223,12 @@ func (rp *ResolvingPath) Final() bool { return rp.curPart == 0 && !rp.pit.NextOk() } +// Pit returns a copy of rp's current path iterator. Modifying the iterator +// does not change rp. +func (rp *ResolvingPath) Pit() fspath.Iterator { + return rp.pit +} + // Component returns the current path component in the stream represented by // rp. // diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go index 8998a82dd..8a6ced365 100644 --- a/pkg/sentry/vfs/save_restore.go +++ b/pkg/sentry/vfs/save_restore.go @@ -15,7 +15,6 @@ package vfs import ( - "fmt" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -24,6 +23,18 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// ErrCorruption indicates a failed restore due to external file system state in +// corruption. +type ErrCorruption struct { + // Err is the wrapped error. + Err error +} + +// Error returns a sensible description of the restore error. +func (e ErrCorruption) Error() string { + return "restore failed due to external file system state in corruption: " + e.Err.Error() +} + // FilesystemImplSaveRestoreExtension is an optional extension to // FilesystemImpl. type FilesystemImplSaveRestoreExtension interface { @@ -37,38 +48,30 @@ type FilesystemImplSaveRestoreExtension interface { // PrepareSave prepares all filesystems for serialization. func (vfs *VirtualFilesystem) PrepareSave(ctx context.Context) error { - failures := 0 for fs := range vfs.getFilesystems() { if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok { if err := ext.PrepareSave(ctx); err != nil { - ctx.Warningf("%T.PrepareSave failed: %v", fs.impl, err) - failures++ + fs.DecRef(ctx) + return err } } fs.DecRef(ctx) } - if failures != 0 { - return fmt.Errorf("%d filesystems failed to prepare for serialization", failures) - } return nil } // CompleteRestore completes restoration from checkpoint for all filesystems // after deserialization. func (vfs *VirtualFilesystem) CompleteRestore(ctx context.Context, opts *CompleteRestoreOptions) error { - failures := 0 for fs := range vfs.getFilesystems() { if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok { if err := ext.CompleteRestore(ctx, *opts); err != nil { - ctx.Warningf("%T.CompleteRestore failed: %v", fs.impl, err) - failures++ + fs.DecRef(ctx) + return err } } fs.DecRef(ctx) } - if failures != 0 { - return fmt.Errorf("%d filesystems failed to complete restore after deserialization", failures) - } return nil } diff --git a/pkg/shim/service.go b/pkg/shim/service.go index 24e3b7a82..0980d964e 100644 --- a/pkg/shim/service.go +++ b/pkg/shim/service.go @@ -77,6 +77,8 @@ const ( // shimAddressPath is the relative path to a file that contains the address // to the shim UDS. See service.shimAddress. shimAddressPath = "address" + + cgroupParentAnnotation = "dev.gvisor.spec.cgroup-parent" ) // New returns a new shim service that can be used via GRPC. @@ -952,7 +954,7 @@ func newInit(path, workDir, namespace string, platform stdio.Platform, r *proc.C if err != nil { return nil, fmt.Errorf("update volume annotations: %w", err) } - updated = updateCgroup(spec) || updated + updated = setPodCgroup(spec) || updated if updated { if err := utils.WriteSpec(r.Bundle, spec); err != nil { @@ -980,12 +982,13 @@ func newInit(path, workDir, namespace string, platform stdio.Platform, r *proc.C return p, nil } -// updateCgroup updates cgroup path for the sandbox to make the sandbox join the -// pod cgroup and not the pause container cgroup. Returns true if the spec was -// modified. Ex.: -// /kubepods/burstable/pod123/abc => kubepods/burstable/pod123 +// setPodCgroup searches for the pod cgroup path inside the container's cgroup +// path. If found, it's set as an annotation in the spec. This is done so that +// the sandbox joins the pod cgroup. Otherwise, the sandbox would join the pause +// container cgroup. Returns true if the spec was modified. Ex.: +// /kubepods/burstable/pod123/container123 => kubepods/burstable/pod123 // -func updateCgroup(spec *specs.Spec) bool { +func setPodCgroup(spec *specs.Spec) bool { if !utils.IsSandbox(spec) { return false } @@ -1009,7 +1012,10 @@ func updateCgroup(spec *specs.Spec) bool { if spec.Linux.CgroupsPath == path { return false } - spec.Linux.CgroupsPath = path + if spec.Annotations == nil { + spec.Annotations = make(map[string]string) + } + spec.Annotations[cgroupParentAnnotation] = path return true } } diff --git a/pkg/shim/service_test.go b/pkg/shim/service_test.go index 2d9f07e02..4b4410a58 100644 --- a/pkg/shim/service_test.go +++ b/pkg/shim/service_test.go @@ -40,12 +40,12 @@ func TestCgroupPath(t *testing.T) { { name: "no-container", path: "foo/pod123", - want: "foo/pod123", + want: "", }, { name: "no-container-absolute", path: "/foo/pod123", - want: "/foo/pod123", + want: "", }, { name: "double-pod", @@ -70,7 +70,7 @@ func TestCgroupPath(t *testing.T) { { name: "no-pod", path: "/foo/nopod123/container", - want: "/foo/nopod123/container", + want: "", }, } { t.Run(tc.name, func(t *testing.T) { @@ -79,12 +79,12 @@ func TestCgroupPath(t *testing.T) { CgroupsPath: tc.path, }, } - updated := updateCgroup(&spec) - if spec.Linux.CgroupsPath != tc.want { - t.Errorf("updateCgroup(%q), want: %q, got: %q", tc.path, tc.want, spec.Linux.CgroupsPath) + updated := setPodCgroup(&spec) + if got := spec.Annotations[cgroupParentAnnotation]; got != tc.want { + t.Errorf("setPodCgroup(%q), want: %q, got: %q", tc.path, tc.want, got) } - if shouldUpdate := tc.path != tc.want; shouldUpdate != updated { - t.Errorf("updateCgroup(%q)=%v, want: %v", tc.path, updated, shouldUpdate) + if shouldUpdate := len(tc.want) > 0; shouldUpdate != updated { + t.Errorf("setPodCgroup(%q)=%v, want: %v", tc.path, updated, shouldUpdate) } }) } @@ -113,8 +113,8 @@ func TestCgroupNoUpdate(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { - if updated := updateCgroup(tc.spec); updated { - t.Errorf("updateCgroup(%+v), got: %v, want: false", tc.spec.Linux, updated) + if updated := setPodCgroup(tc.spec); updated { + t.Errorf("setPodCgroup(%+v), got: %v, want: false", tc.spec.Linux, updated) } }) } diff --git a/pkg/sentry/sighandling/BUILD b/pkg/sighandling/BUILD index 1790d57c9..72f10f982 100644 --- a/pkg/sentry/sighandling/BUILD +++ b/pkg/sighandling/BUILD @@ -8,7 +8,7 @@ go_library( "sighandling.go", "sighandling_unsafe.go", ], - visibility = ["//pkg/sentry:internal"], + visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/sentry/sighandling/sighandling.go b/pkg/sighandling/sighandling.go index bdaf8af29..bdaf8af29 100644 --- a/pkg/sentry/sighandling/sighandling.go +++ b/pkg/sighandling/sighandling.go diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sighandling/sighandling_unsafe.go index 3fe5c6770..7deeda042 100644 --- a/pkg/sentry/sighandling/sighandling_unsafe.go +++ b/pkg/sighandling/sighandling_unsafe.go @@ -15,6 +15,7 @@ package sighandling import ( + "fmt" "unsafe" "golang.org/x/sys/unix" @@ -37,3 +38,36 @@ func IgnoreChildStop() error { return nil } + +// ReplaceSignalHandler replaces the existing signal handler for the provided +// signal with the function pointer at `handler`. This bypasses the Go runtime +// signal handlers, and should only be used for low-level signal handlers where +// use of signal.Notify is not appropriate. +// +// It stores the value of the previously set handler in previous. +func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) error { + var sa linux.SigAction + const maskLen = 8 + + // Get the existing signal handler information, and save the current + // handler. Once we replace it, we will use this pointer to fall back to + // it when we receive other signals. + if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 { + return e + } + + // Fail if there isn't a previous handler. + if sa.Handler == 0 { + return fmt.Errorf("previous handler for signal %x isn't set", sig) + } + + *previous = uintptr(sa.Handler) + + // Install our own handler. + sa.Handler = uint64(handler) + if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 { + return e + } + + return nil +} diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index 73791b456..517f16329 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -26,6 +26,7 @@ go_library( "rwmutex_unsafe.go", "seqcount.go", "sync.go", + "wait.go", ], marshal = False, stateify = False, diff --git a/pkg/sync/atomicptr/generic_atomicptr_unsafe.go b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go index 82b6df18c..7b9c2a4db 100644 --- a/pkg/sync/atomicptr/generic_atomicptr_unsafe.go +++ b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go @@ -37,6 +37,8 @@ func (p *AtomicPtr) loadPtr(v *Value) { // Load returns the value set by the most recent Store. It returns nil if there // has been no previous call to Store. +// +//go:nosplit func (p *AtomicPtr) Load() *Value { return (*Value)(atomic.LoadPointer(&p.ptr)) } diff --git a/pkg/sync/wait.go b/pkg/sync/wait.go new file mode 100644 index 000000000..f8e7742a5 --- /dev/null +++ b/pkg/sync/wait.go @@ -0,0 +1,58 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sync + +// WaitGroupErr is similar to WaitGroup but allows goroutines to report error. +// Only the first error is retained and reported back. +// +// Example usage: +// wg := WaitGroupErr{} +// wg.Add(1) +// go func() { +// defer wg.Done() +// if err := ...; err != nil { +// wg.ReportError(err) +// return +// } +// }() +// return wg.Error() +// +type WaitGroupErr struct { + WaitGroup + + // mu protects firstErr. + mu Mutex + + // firstErr holds the first error reported. nil is no error occurred. + firstErr error +} + +// ReportError reports an error. Note it does not call Done(). +func (w *WaitGroupErr) ReportError(err error) { + w.mu.Lock() + defer w.mu.Unlock() + if w.firstErr == nil { + w.firstErr = err + } +} + +// Error waits for the counter to reach 0 and returns the first reported error +// if any. +func (w *WaitGroupErr) Error() error { + w.Wait() + w.mu.Lock() + defer w.mu.Unlock() + return w.firstErr +} diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index dbe4506cc..b98de54c5 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -25,6 +25,7 @@ go_library( "stdclock.go", "stdclock_state.go", "tcpip.go", + "tcpip_state.go", "timer.go", ], visibility = ["//visibility:public"], diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 010e2e833..1f2bcaf65 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -19,6 +19,7 @@ import ( "bytes" "context" "errors" + "fmt" "io" "net" "time" @@ -471,9 +472,9 @@ func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtoc return DialContextTCP(context.Background(), s, addr, network) } -// DialContextTCP creates a new TCPConn connected to the specified address -// with the option of adding cancellation and timeouts. -func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { +// DialTCPWithBind creates a new TCPConn connected to the specified +// remoteAddress with its local address bound to localAddr. +func DialTCPWithBind(ctx context.Context, s *stack.Stack, localAddr, remoteAddr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { // Create TCP endpoint, then connect. var wq waiter.Queue ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) @@ -494,7 +495,14 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, default: } - err = ep.Connect(addr) + // Bind before connect if requested. + if localAddr != (tcpip.FullAddress{}) { + if err = ep.Bind(localAddr); err != nil { + return nil, fmt.Errorf("ep.Bind(%+v) = %s", localAddr, err) + } + } + + err = ep.Connect(remoteAddr) if _, ok := err.(*tcpip.ErrConnectStarted); ok { select { case <-ctx.Done(): @@ -510,7 +518,7 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, return nil, &net.OpError{ Op: "connect", Net: "tcp", - Addr: fullToTCPAddr(addr), + Addr: fullToTCPAddr(remoteAddr), Err: errors.New(err.String()), } } @@ -518,6 +526,12 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, return NewTCPConn(&wq, ep), nil } +// DialContextTCP creates a new TCPConn connected to the specified address +// with the option of adding cancellation and timeouts. +func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { + return DialTCPWithBind(ctx, s, tcpip.FullAddress{} /* localAddr */, addr /* remoteAddr */, network) +} + // A UDPConn is a wrapper around a UDP tcpip.Endpoint that implements // net.Conn and net.PacketConn. type UDPConn struct { diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 48b24692b..dcc9fff17 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -137,7 +137,13 @@ func TestCloseReader(t *testing.T) { addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) if e != nil { @@ -190,7 +196,13 @@ func TestCloseReaderWithForwarder(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } done := make(chan struct{}) @@ -244,7 +256,13 @@ func TestCloseRead(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { var wq waiter.Queue @@ -288,7 +306,13 @@ func TestCloseWrite(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { var wq waiter.Queue @@ -349,10 +373,22 @@ func TestUDPForwarder(t *testing.T) { ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip1.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err) + } ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) addr2 := tcpip.FullAddress{NICID, ip2, 11311} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip2.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err) + } done := make(chan struct{}) fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) { @@ -410,7 +446,13 @@ func TestDeadlineChange(t *testing.T) { addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) if e != nil { @@ -465,10 +507,22 @@ func TestPacketConnTransfer(t *testing.T) { ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip1.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err) + } ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) addr2 := tcpip.FullAddress{NICID, ip2, 11311} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip2.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err) + } c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber) if err != nil { @@ -521,7 +575,13 @@ func TestConnectedPacketConnTransfer(t *testing.T) { ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber) if err != nil { @@ -565,24 +625,30 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + return nil, nil, nil, fmt.Errorf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } l, err := ListenTCP(s, addr, ipv4.ProtocolNumber) if err != nil { - return nil, nil, nil, fmt.Errorf("NewListener: %v", err) + return nil, nil, nil, fmt.Errorf("NewListener: %w", err) } c1, err = DialTCP(s, addr, ipv4.ProtocolNumber) if err != nil { l.Close() - return nil, nil, nil, fmt.Errorf("DialTCP: %v", err) + return nil, nil, nil, fmt.Errorf("DialTCP: %w", err) } c2, err = l.Accept() if err != nil { l.Close() c1.Close() - return nil, nil, nil, fmt.Errorf("l.Accept: %v", err) + return nil, nil, nil, fmt.Errorf("l.Accept: %w", err) } stop = func() { @@ -594,7 +660,7 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { if err := l.Close(); err != nil { stop() - return nil, nil, nil, fmt.Errorf("l.Close(): %v", err) + return nil, nil, nil, fmt.Errorf("l.Close(): %w", err) } return c1, c2, stop, nil @@ -681,7 +747,13 @@ func TestDialContextTCPCanceled(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } ctx := context.Background() ctx, cancel := context.WithCancel(ctx) @@ -703,7 +775,13 @@ func TestDialContextTCPTimeout(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { time.Sleep(time.Second) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 2f34bf8dd..24c2c3e6b 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -324,6 +324,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { } } +// ReceiveIPv6PacketInfo creates a checker that checks the IPv6PacketInfo field +// in ControlMessages. +func ReceiveIPv6PacketInfo(want tcpip.IPv6PacketInfo) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasIPv6PacketInfo { + t.Errorf("got cm.HasIPv6PacketInfo = %t, want = true", cm.HasIPv6PacketInfo) + } else if diff := cmp.Diff(want, cm.IPv6PacketInfo); diff != "" { + t.Errorf("IPv6PacketInfo mismatch (-want +got):\n%s", diff) + } + } +} + // ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress // field in ControlMessages. func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker { diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index dcc549c7b..7baaf0d17 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -208,6 +208,15 @@ var IPv4EmptySubnet = func() tcpip.Subnet { return subnet }() +// IPv4LoopbackSubnet is the loopback subnet for IPv4. +var IPv4LoopbackSubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet(tcpip.Address("\x7f\x00\x00\x00"), tcpip.AddressMask("\xff\x00\x00\x00")) + if err != nil { + panic(err) + } + return subnet +}() + // IPVersion returns the version of IP used in the given packet. It returns -1 // if the packet is not large enough to contain the version field. func IPVersion(b []byte) int { diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go index 1c913b5e1..80a9ad6be 100644 --- a/pkg/tcpip/header/parse/parse.go +++ b/pkg/tcpip/header/parse/parse.go @@ -110,6 +110,16 @@ traverseExtensions: switch extHdr := extHdr.(type) { case header.IPv6FragmentExtHdr: + if extHdr.IsAtomic() { + // This fragment extension header indicates that this packet is an + // atomic fragment. An atomic fragment is a fragment that contains + // all the data required to reassemble a full packet. As per RFC 6946, + // atomic fragments must not interfere with "normal" fragmented traffic + // so we skip processing the fragment instead of feeding it through the + // reassembly process below. + continue + } + if fragID == 0 && fragOffset == 0 && !fragMore { fragID = extHdr.ID() fragOffset = extHdr.FragmentOffset() @@ -175,3 +185,61 @@ func TCP(pkt *stack.PacketBuffer) bool { pkt.TransportProtocolNumber = header.TCPProtocolNumber return ok } + +// ICMPv4 populates the packet buffer's transport header with an ICMPv4 header, +// if present. +// +// Returns true if an ICMPv4 header was successfully parsed. +func ICMPv4(pkt *stack.PacketBuffer) bool { + if _, ok := pkt.TransportHeader().Consume(header.ICMPv4MinimumSize); ok { + pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber + return true + } + return false +} + +// ICMPv6 populates the packet buffer's transport header with an ICMPv4 header, +// if present. +// +// Returns true if an ICMPv6 header was successfully parsed. +func ICMPv6(pkt *stack.PacketBuffer) bool { + hdr, ok := pkt.Data().PullUp(header.ICMPv6MinimumSize) + if !ok { + return false + } + + h := header.ICMPv6(hdr) + switch h.Type() { + case header.ICMPv6RouterSolicit, + header.ICMPv6RouterAdvert, + header.ICMPv6NeighborSolicit, + header.ICMPv6NeighborAdvert, + header.ICMPv6RedirectMsg: + size := pkt.Data().Size() + if _, ok := pkt.TransportHeader().Consume(size); !ok { + panic(fmt.Sprintf("expected to consume the full data of size = %d bytes into transport header", size)) + } + case header.ICMPv6MulticastListenerQuery, + header.ICMPv6MulticastListenerReport, + header.ICMPv6MulticastListenerDone: + size := header.ICMPv6HeaderSize + header.MLDMinimumSize + if _, ok := pkt.TransportHeader().Consume(size); !ok { + return false + } + case header.ICMPv6DstUnreachable, + header.ICMPv6PacketTooBig, + header.ICMPv6TimeExceeded, + header.ICMPv6ParamProblem, + header.ICMPv6EchoRequest, + header.ICMPv6EchoReply: + fallthrough + default: + if _, ok := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize); !ok { + // Checked above if the packet buffer holds at least the minimum size for + // an ICMPv6 packet. + panic(fmt.Sprintf("expected to consume %d bytes", header.ICMPv6MinimumSize)) + } + } + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber + return true +} diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index 3ed0aa3fe..c67ca98ea 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -123,4 +123,6 @@ func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber } // WriteRawPacket implements stack.LinkEndpoint. -func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } +func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error { + return e.WritePacket(stack.RouteInfo{}, 0, pkt) +} diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 87a0b9a62..e53789d92 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -152,10 +152,22 @@ type PollEvent struct { // no data is available, it will block in a poll() syscall until the file // descriptor becomes readable. func BlockingRead(fd int, b []byte) (int, tcpip.Error) { + n, err := BlockingReadUntranslated(fd, b) + if err != 0 { + return n, TranslateErrno(err) + } + return n, nil +} + +// BlockingReadUntranslated reads from a file descriptor that is set up as +// non-blocking. If no data is available, it will block in a poll() syscall +// until the file descriptor becomes readable. It returns the raw unix.Errno +// value returned by the underlying syscalls. +func BlockingReadUntranslated(fd int, b []byte) (int, unix.Errno) { for { n, _, e := unix.RawSyscall(unix.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b))) if e == 0 { - return int(n), nil + return int(n), 0 } event := PollEvent{ @@ -165,7 +177,7 @@ func BlockingRead(fd int, b []byte) (int, tcpip.Error) { _, e = BlockingPoll(&event, 1, nil) if e != 0 && e != unix.EINTR { - return 0, TranslateErrno(e) + return 0, e } } } diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index 4215ee852..af755473c 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -5,19 +5,27 @@ package(licenses = ["notice"]) go_library( name = "sharedmem", srcs = [ + "queuepair.go", "rx.go", + "server_rx.go", + "server_tx.go", "sharedmem.go", + "sharedmem_server.go", "sharedmem_unsafe.go", "tx.go", ], visibility = ["//visibility:public"], deps = [ + "//pkg/cleanup", + "//pkg/eventfd", "//pkg/log", + "//pkg/memutil", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/link/rawfile", + "//pkg/tcpip/link/sharedmem/pipe", "//pkg/tcpip/link/sharedmem/queue", "//pkg/tcpip/stack", "@org_golang_x_sys//unix:go_default_library", @@ -26,9 +34,7 @@ go_library( go_test( name = "sharedmem_test", - srcs = [ - "sharedmem_test.go", - ], + srcs = ["sharedmem_test.go"], library = ":sharedmem", deps = [ "//pkg/sync", @@ -41,3 +47,22 @@ go_test( "@org_golang_x_sys//unix:go_default_library", ], ) + +go_test( + name = "sharedmem_server_test", + size = "small", + srcs = ["sharedmem_server_test.go"], + deps = [ + ":sharedmem", + "//pkg/tcpip", + "//pkg/tcpip/adapters/gonet", + "//pkg/tcpip/header", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/tcp", + "//pkg/tcpip/transport/udp", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/tcpip/link/sharedmem/queue/rx.go b/pkg/tcpip/link/sharedmem/queue/rx.go index 696e6c9e5..a78826ebc 100644 --- a/pkg/tcpip/link/sharedmem/queue/rx.go +++ b/pkg/tcpip/link/sharedmem/queue/rx.go @@ -119,7 +119,6 @@ func (r *Rx) PostBuffers(buffers []RxBuffer) bool { } r.tx.Flush() - return true } @@ -131,7 +130,6 @@ func (r *Rx) PostBuffers(buffers []RxBuffer) bool { func (r *Rx) Dequeue(bufs []RxBuffer) ([]RxBuffer, uint32) { for { outBufs := bufs - // Pull the next descriptor from the rx pipe. b := r.rx.Pull() if b == nil { diff --git a/pkg/tcpip/link/sharedmem/queuepair.go b/pkg/tcpip/link/sharedmem/queuepair.go new file mode 100644 index 000000000..b12647fdd --- /dev/null +++ b/pkg/tcpip/link/sharedmem/queuepair.go @@ -0,0 +1,199 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package sharedmem + +import ( + "fmt" + "io/ioutil" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/eventfd" +) + +const ( + // defaultQueueDataSize is the size of the shared memory data region that + // holds the scatter/gather buffers. + defaultQueueDataSize = 1 << 20 // 1MiB + + // defaultQueuePipeSize is the size of the pipe that holds the packet descriptors. + // + // Assuming each packet data is approximately 1280 bytes (IPv6 Minimum MTU) + // then we can hold approximately 1024*1024/1280 ~ 819 packets in the data + // area. Which means the pipe needs to be big enough to hold 819 + // descriptors. + // + // Each descriptor is approximately 8 (slot descriptor in pipe) + + // 16 (packet descriptor) + 12 (for buffer descriptor) assuming each packet is + // stored in exactly 1 buffer descriptor (see queue/tx.go and pipe/tx.go.) + // + // Which means we need approximately 36*819 ~ 29 KiB to store all packet + // descriptors. We could go with a 32 KiB pipe but to give it some slack in + // how the upper layer may make use of the scatter gather buffers we double + // this to hold enough descriptors. + defaultQueuePipeSize = 64 << 10 // 64KiB + + // defaultSharedDataSize is the size of the sharedData region used to + // enable/disable notifications. + defaultSharedDataSize = 4 << 10 // 4KiB +) + +// A QueuePair represents a pair of TX/RX queues. +type QueuePair struct { + // txCfg is the QueueConfig to be used for transmit queue. + txCfg QueueConfig + + // rxCfg is the QueueConfig to be used for receive queue. + rxCfg QueueConfig +} + +// NewQueuePair creates a shared memory QueuePair. +func NewQueuePair() (*QueuePair, error) { + txCfg, err := createQueueFDs(queueSizes{ + dataSize: defaultQueueDataSize, + txPipeSize: defaultQueuePipeSize, + rxPipeSize: defaultQueuePipeSize, + sharedDataSize: defaultSharedDataSize, + }) + + if err != nil { + return nil, fmt.Errorf("failed to create tx queue: %s", err) + } + + rxCfg, err := createQueueFDs(queueSizes{ + dataSize: defaultQueueDataSize, + txPipeSize: defaultQueuePipeSize, + rxPipeSize: defaultQueuePipeSize, + sharedDataSize: defaultSharedDataSize, + }) + + if err != nil { + closeFDs(txCfg) + return nil, fmt.Errorf("failed to create rx queue: %s", err) + } + + return &QueuePair{ + txCfg: txCfg, + rxCfg: rxCfg, + }, nil +} + +// Close closes underlying tx/rx queue fds. +func (q *QueuePair) Close() { + closeFDs(q.txCfg) + closeFDs(q.rxCfg) +} + +// TXQueueConfig returns the QueueConfig for the receive queue. +func (q *QueuePair) TXQueueConfig() QueueConfig { + return q.txCfg +} + +// RXQueueConfig returns the QueueConfig for the transmit queue. +func (q *QueuePair) RXQueueConfig() QueueConfig { + return q.rxCfg +} + +type queueSizes struct { + dataSize int64 + txPipeSize int64 + rxPipeSize int64 + sharedDataSize int64 +} + +func createQueueFDs(s queueSizes) (QueueConfig, error) { + success := false + var eventFD eventfd.Eventfd + var dataFD, txPipeFD, rxPipeFD, sharedDataFD int + defer func() { + if success { + return + } + closeFDs(QueueConfig{ + EventFD: eventFD, + DataFD: dataFD, + TxPipeFD: txPipeFD, + RxPipeFD: rxPipeFD, + SharedDataFD: sharedDataFD, + }) + }() + eventFD, err := eventfd.Create() + if err != nil { + return QueueConfig{}, fmt.Errorf("eventfd failed: %v", err) + } + dataFD, err = createFile(s.dataSize, false) + if err != nil { + return QueueConfig{}, fmt.Errorf("failed to create dataFD: %s", err) + } + txPipeFD, err = createFile(s.txPipeSize, true) + if err != nil { + return QueueConfig{}, fmt.Errorf("failed to create txPipeFD: %s", err) + } + rxPipeFD, err = createFile(s.rxPipeSize, true) + if err != nil { + return QueueConfig{}, fmt.Errorf("failed to create rxPipeFD: %s", err) + } + sharedDataFD, err = createFile(s.sharedDataSize, false) + if err != nil { + return QueueConfig{}, fmt.Errorf("failed to create sharedDataFD: %s", err) + } + success = true + return QueueConfig{ + EventFD: eventFD, + DataFD: dataFD, + TxPipeFD: txPipeFD, + RxPipeFD: rxPipeFD, + SharedDataFD: sharedDataFD, + }, nil +} + +func createFile(size int64, initQueue bool) (fd int, err error) { + const tmpDir = "/dev/shm/" + f, err := ioutil.TempFile(tmpDir, "sharedmem_test") + if err != nil { + return -1, fmt.Errorf("TempFile failed: %v", err) + } + defer f.Close() + unix.Unlink(f.Name()) + + if initQueue { + // Write the "slot-free" flag in the initial queue. + if _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0); err != nil { + return -1, fmt.Errorf("WriteAt failed: %v", err) + } + } + + fd, err = unix.Dup(int(f.Fd())) + if err != nil { + return -1, fmt.Errorf("unix.Dup(%d) failed: %v", f.Fd(), err) + } + + if err := unix.Ftruncate(fd, size); err != nil { + unix.Close(fd) + return -1, fmt.Errorf("ftruncate(%d, %d) failed: %v", fd, size, err) + } + + return fd, nil +} + +func closeFDs(c QueueConfig) { + unix.Close(c.DataFD) + c.EventFD.Close() + unix.Close(c.TxPipeFD) + unix.Close(c.RxPipeFD) + unix.Close(c.SharedDataFD) +} diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go index e882a128c..87747dcc7 100644 --- a/pkg/tcpip/link/sharedmem/rx.go +++ b/pkg/tcpip/link/sharedmem/rx.go @@ -21,7 +21,7 @@ import ( "sync/atomic" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" + "gvisor.dev/gvisor/pkg/eventfd" "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" ) @@ -30,7 +30,7 @@ type rx struct { data []byte sharedData []byte q queue.Rx - eventFD int + eventFD eventfd.Eventfd } // init initializes all state needed by the rx queue based on the information @@ -68,7 +68,7 @@ func (r *rx) init(mtu uint32, c *QueueConfig) error { // Duplicate the eventFD so that caller can close it but we can still // use it. - efd, err := unix.Dup(c.EventFD) + efd, err := c.EventFD.Dup() if err != nil { unix.Munmap(txPipe) unix.Munmap(rxPipe) @@ -77,16 +77,6 @@ func (r *rx) init(mtu uint32, c *QueueConfig) error { return err } - // Set the eventfd as non-blocking. - if err := unix.SetNonblock(efd, true); err != nil { - unix.Munmap(txPipe) - unix.Munmap(rxPipe) - unix.Munmap(data) - unix.Munmap(sharedData) - unix.Close(efd) - return err - } - // Initialize state based on buffers. r.q.Init(txPipe, rxPipe, sharedDataPointer(sharedData)) r.data = data @@ -105,7 +95,13 @@ func (r *rx) cleanup() { unix.Munmap(r.data) unix.Munmap(r.sharedData) - unix.Close(r.eventFD) + r.eventFD.Close() +} + +// notify writes to the tx.eventFD to indicate to the peer that there is data to +// be read. +func (r *rx) notify() { + r.eventFD.Notify() } // postAndReceive posts the provided buffers (if any), and then tries to read @@ -122,8 +118,7 @@ func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue. if len(b) != 0 && !r.q.PostBuffers(b) { r.q.EnableNotification() for !r.q.PostBuffers(b) { - var tmp [8]byte - rawfile.BlockingRead(r.eventFD, tmp[:]) + r.eventFD.Wait() if atomic.LoadUint32(stopRequested) != 0 { r.q.DisableNotification() return nil, 0 @@ -147,8 +142,7 @@ func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue. } // Wait for notification. - var tmp [8]byte - rawfile.BlockingRead(r.eventFD, tmp[:]) + r.eventFD.Wait() if atomic.LoadUint32(stopRequested) != 0 { r.q.DisableNotification() return nil, 0 diff --git a/pkg/tcpip/link/sharedmem/server_rx.go b/pkg/tcpip/link/sharedmem/server_rx.go new file mode 100644 index 000000000..6ea21ffd1 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/server_rx.go @@ -0,0 +1,142 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package sharedmem + +import ( + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/cleanup" + "gvisor.dev/gvisor/pkg/eventfd" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" +) + +type serverRx struct { + // packetPipe represents the receive end of the pipe that carries the packet + // descriptors sent by the client. + packetPipe pipe.Rx + + // completionPipe represents the transmit end of the pipe that will carry + // completion notifications from the server to the client. + completionPipe pipe.Tx + + // data represents the buffer area where the packet payload is held. + data []byte + + // eventFD is used to notify the peer when transmission is completed. + eventFD eventfd.Eventfd + + // sharedData the memory region to use to enable/disable notifications. + sharedData []byte +} + +// init initializes all state needed by the serverTx queue based on the +// information provided. +// +// The caller always retains ownership of all file descriptors passed in. The +// queue implementation will duplicate any that it may need in the future. +func (s *serverRx) init(c *QueueConfig) error { + // Map in all buffers. + packetPipeMem, err := getBuffer(c.TxPipeFD) + if err != nil { + return err + } + cu := cleanup.Make(func() { unix.Munmap(packetPipeMem) }) + defer cu.Clean() + + completionPipeMem, err := getBuffer(c.RxPipeFD) + if err != nil { + return err + } + cu.Add(func() { unix.Munmap(completionPipeMem) }) + + data, err := getBuffer(c.DataFD) + if err != nil { + return err + } + cu.Add(func() { unix.Munmap(data) }) + + sharedData, err := getBuffer(c.SharedDataFD) + if err != nil { + return err + } + cu.Add(func() { unix.Munmap(sharedData) }) + + // Duplicate the eventFD so that caller can close it but we can still + // use it. + efd, err := c.EventFD.Dup() + if err != nil { + return err + } + cu.Add(func() { efd.Close() }) + + s.packetPipe.Init(packetPipeMem) + s.completionPipe.Init(completionPipeMem) + s.data = data + s.eventFD = efd + s.sharedData = sharedData + + cu.Release() + return nil +} + +func (s *serverRx) cleanup() { + unix.Munmap(s.packetPipe.Bytes()) + unix.Munmap(s.completionPipe.Bytes()) + unix.Munmap(s.data) + unix.Munmap(s.sharedData) + s.eventFD.Close() +} + +// completionNotificationSize is size in bytes of a completion notification sent +// on the completion queue after a transmitted packet has been handled. +const completionNotificationSize = 8 + +// receive receives a single packet from the packetPipe. +func (s *serverRx) receive() []byte { + desc := s.packetPipe.Pull() + if desc == nil { + return nil + } + + pktInfo := queue.DecodeTxPacketHeader(desc) + contents := make([]byte, 0, pktInfo.Size) + toCopy := pktInfo.Size + for i := 0; i < pktInfo.BufferCount; i++ { + txBuf := queue.DecodeTxBufferHeader(desc, i) + if txBuf.Size <= toCopy { + contents = append(contents, s.data[txBuf.Offset:][:txBuf.Size]...) + toCopy -= txBuf.Size + continue + } + contents = append(contents, s.data[txBuf.Offset:][:toCopy]...) + break + } + + // Flush to let peer know that slots queued for transmission have been handled + // and its free to reuse the slots. + s.packetPipe.Flush() + // Encode packet completion. + b := s.completionPipe.Push(completionNotificationSize) + queue.EncodeTxCompletion(b, pktInfo.ID) + s.completionPipe.Flush() + return contents +} + +func (s *serverRx) waitForPackets() { + s.eventFD.Wait() +} diff --git a/pkg/tcpip/link/sharedmem/server_tx.go b/pkg/tcpip/link/sharedmem/server_tx.go new file mode 100644 index 000000000..13a82903f --- /dev/null +++ b/pkg/tcpip/link/sharedmem/server_tx.go @@ -0,0 +1,175 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package sharedmem + +import ( + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/cleanup" + "gvisor.dev/gvisor/pkg/eventfd" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" +) + +// serverTx represents the server end of the sharedmem queue and is used to send +// packets to the peer in the buffers posted by the peer in the fillPipe. +type serverTx struct { + // fillPipe represents the receive end of the pipe that carries the RxBuffers + // posted by the peer. + fillPipe pipe.Rx + + // completionPipe represents the transmit end of the pipe that carries the + // descriptors for filled RxBuffers. + completionPipe pipe.Tx + + // data represents the buffer area where the packet payload is held. + data []byte + + // eventFD is used to notify the peer when fill requests are fulfilled. + eventFD eventfd.Eventfd + + // sharedData the memory region to use to enable/disable notifications. + sharedData []byte +} + +// init initializes all tstate needed by the serverTx queue based on the +// information provided. +// +// The caller always retains ownership of all file descriptors passed in. The +// queue implementation will duplicate any that it may need in the future. +func (s *serverTx) init(c *QueueConfig) error { + // Map in all buffers. + fillPipeMem, err := getBuffer(c.TxPipeFD) + if err != nil { + return err + } + cu := cleanup.Make(func() { unix.Munmap(fillPipeMem) }) + defer cu.Clean() + + completionPipeMem, err := getBuffer(c.RxPipeFD) + if err != nil { + return err + } + cu.Add(func() { unix.Munmap(completionPipeMem) }) + + data, err := getBuffer(c.DataFD) + if err != nil { + return err + } + cu.Add(func() { unix.Munmap(data) }) + + sharedData, err := getBuffer(c.SharedDataFD) + if err != nil { + return err + } + cu.Add(func() { unix.Munmap(sharedData) }) + + // Duplicate the eventFD so that caller can close it but we can still + // use it. + efd, err := c.EventFD.Dup() + if err != nil { + return err + } + cu.Add(func() { efd.Close() }) + + cu.Release() + + s.fillPipe.Init(fillPipeMem) + s.completionPipe.Init(completionPipeMem) + s.data = data + s.eventFD = efd + s.sharedData = sharedData + + return nil +} + +func (s *serverTx) cleanup() { + unix.Munmap(s.fillPipe.Bytes()) + unix.Munmap(s.completionPipe.Bytes()) + unix.Munmap(s.data) + unix.Munmap(s.sharedData) + s.eventFD.Close() +} + +// fillPacket copies the data in the provided views into buffers pulled from the +// fillPipe and returns a slice of RxBuffers that contain the copied data as +// well as the total number of bytes copied. +// +// To avoid allocations the filledBuffers are appended to the buffers slice +// which will be grown as required. +func (s *serverTx) fillPacket(views []buffer.View, buffers []queue.RxBuffer) (filledBuffers []queue.RxBuffer, totalCopied uint32) { + filledBuffers = buffers[:0] + // fillBuffer copies as much of the views as possible into the provided buffer + // and returns any left over views (if any). + fillBuffer := func(buffer *queue.RxBuffer, views []buffer.View) (left []buffer.View) { + if len(views) == 0 { + return nil + } + availBytes := buffer.Size + copied := uint64(0) + for availBytes > 0 && len(views) > 0 { + n := copy(s.data[buffer.Offset+copied:][:uint64(buffer.Size)-copied], views[0]) + views[0].TrimFront(n) + if !views[0].IsEmpty() { + break + } + views = views[1:] + copied += uint64(n) + availBytes -= uint32(n) + } + buffer.Size = uint32(copied) + return views + } + + for len(views) > 0 { + var b []byte + // Spin till we get a free buffer reposted by the peer. + for { + if b = s.fillPipe.Pull(); b != nil { + break + } + } + rxBuffer := queue.DecodeRxBufferHeader(b) + // Copy the packet into the posted buffer. + views = fillBuffer(&rxBuffer, views) + totalCopied += rxBuffer.Size + filledBuffers = append(filledBuffers, rxBuffer) + } + + return filledBuffers, totalCopied +} + +func (s *serverTx) transmit(views []buffer.View) bool { + buffers := make([]queue.RxBuffer, 8) + buffers, totalCopied := s.fillPacket(views, buffers) + b := s.completionPipe.Push(queue.RxCompletionSize(len(buffers))) + if b == nil { + return false + } + queue.EncodeRxCompletion(b, totalCopied, 0 /* reserved */) + for i := 0; i < len(buffers); i++ { + queue.EncodeRxCompletionBuffer(b, i, buffers[i]) + } + s.completionPipe.Flush() + s.fillPipe.Flush() + return true +} + +func (s *serverTx) notify() { + s.eventFD.Notify() +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 66efe6472..b75522a51 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -24,14 +24,16 @@ package sharedmem import ( + "fmt" "sync/atomic" - "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/eventfd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -47,7 +49,7 @@ type QueueConfig struct { // EventFD is a file descriptor for the event that is signaled when // data is becomes available in this queue. - EventFD int + EventFD eventfd.Eventfd // TxPipeFD is a file descriptor for the tx pipe associated with the // queue. @@ -63,16 +65,97 @@ type QueueConfig struct { SharedDataFD int } +// FDs returns the FD's in the QueueConfig as a slice of ints. This must +// be used in conjunction with QueueConfigFromFDs to ensure the order +// of FDs matches when reconstructing the config when serialized or sent +// as part of control messages. +func (q *QueueConfig) FDs() []int { + return []int{q.DataFD, q.EventFD.FD(), q.TxPipeFD, q.RxPipeFD, q.SharedDataFD} +} + +// QueueConfigFromFDs constructs a QueueConfig out of a slice of ints where each +// entry represents an file descriptor. The order of FDs in the slice must be in +// the order specified below for the config to be valid. QueueConfig.FDs() +// should be used when the config needs to be serialized or sent as part of a +// control message to ensure the correct order. +func QueueConfigFromFDs(fds []int) (QueueConfig, error) { + if len(fds) != 5 { + return QueueConfig{}, fmt.Errorf("insufficient number of fds: len(fds): %d, want: 5", len(fds)) + } + return QueueConfig{ + DataFD: fds[0], + EventFD: eventfd.Wrap(fds[1]), + TxPipeFD: fds[2], + RxPipeFD: fds[3], + SharedDataFD: fds[4], + }, nil +} + +// Options specify the details about the sharedmem endpoint to be created. +type Options struct { + // MTU is the mtu to use for this endpoint. + MTU uint32 + + // BufferSize is the size of each scatter/gather buffer that will hold packet + // data. + // + // NOTE: This directly determines number of packets that can be held in + // the ring buffer at any time. This does not have to be sized to the MTU as + // the shared memory queue design allows usage of more than one buffer to be + // used to make up a given packet. + BufferSize uint32 + + // LinkAddress is the link address for this endpoint (required). + LinkAddress tcpip.LinkAddress + + // TX is the transmit queue configuration for this shared memory endpoint. + TX QueueConfig + + // RX is the receive queue configuration for this shared memory endpoint. + RX QueueConfig + + // PeerFD is the fd for the connected peer which can be used to detect + // peer disconnects. + PeerFD int + + // OnClosed is a function that is called when the endpoint is being closed + // (probably due to peer going away) + OnClosed func(err tcpip.Error) + + // TXChecksumOffload if true, indicates that this endpoints capability + // set should include CapabilityTXChecksumOffload. + TXChecksumOffload bool + + // RXChecksumOffload if true, indicates that this endpoints capability + // set should include CapabilityRXChecksumOffload. + RXChecksumOffload bool +} + type endpoint struct { // mtu (maximum transmission unit) is the maximum size of a packet. + // mtu is immutable. mtu uint32 // bufferSize is the size of each individual buffer. + // bufferSize is immutable. bufferSize uint32 // addr is the local address of this endpoint. + // addr is immutable. addr tcpip.LinkAddress + // peerFD is an fd to the peer that can be used to detect when the + // peer is gone. + // peerFD is immutable. + peerFD int + + // caps holds the endpoint capabilities. + caps stack.LinkEndpointCapabilities + + // hdrSize is the size of the link layer header if any. + // hdrSize is immutable. + hdrSize uint32 + // rx is the receive queue. rx rx @@ -83,34 +166,55 @@ type endpoint struct { // Wait group used to indicate that all workers have stopped. completed sync.WaitGroup + // onClosed is a function to be called when the FD's peer (if any) closes + // its end of the communication pipe. + onClosed func(tcpip.Error) + // mu protects the following fields. mu sync.Mutex // tx is the transmit queue. + // +checklocks:mu tx tx // workerStarted specifies whether the worker goroutine was started. + // +checklocks:mu workerStarted bool } // New creates a new shared-memory-based endpoint. Buffers will be broken up // into buffers of "bufferSize" bytes. -func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) { +func New(opts Options) (stack.LinkEndpoint, error) { e := &endpoint{ - mtu: mtu, - bufferSize: bufferSize, - addr: addr, + mtu: opts.MTU, + bufferSize: opts.BufferSize, + addr: opts.LinkAddress, + peerFD: opts.PeerFD, + onClosed: opts.OnClosed, } - if err := e.tx.init(bufferSize, &tx); err != nil { + if err := e.tx.init(opts.BufferSize, &opts.TX); err != nil { return nil, err } - if err := e.rx.init(bufferSize, &rx); err != nil { + if err := e.rx.init(opts.BufferSize, &opts.RX); err != nil { e.tx.cleanup() return nil, err } + e.caps = stack.LinkEndpointCapabilities(0) + if opts.RXChecksumOffload { + e.caps |= stack.CapabilityRXChecksumOffload + } + + if opts.TXChecksumOffload { + e.caps |= stack.CapabilityTXChecksumOffload + } + + if opts.LinkAddress != "" { + e.hdrSize = header.EthernetMinimumSize + e.caps |= stack.CapabilityResolutionRequired + } return e, nil } @@ -119,13 +223,13 @@ func (e *endpoint) Close() { // Tell dispatch goroutine to stop, then write to the eventfd so that // it wakes up in case it's sleeping. atomic.StoreUint32(&e.stopRequested, 1) - unix.Write(e.rx.eventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + e.rx.eventFD.Notify() // Cleanup the queues inline if the worker hasn't started yet; we also // know it won't start from now on because stopRequested is set to 1. e.mu.Lock() + defer e.mu.Unlock() workerPresent := e.workerStarted - e.mu.Unlock() if !workerPresent { e.tx.cleanup() @@ -146,6 +250,22 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 { e.workerStarted = true e.completed.Add(1) + + // Spin up a goroutine to monitor for peer shutdown. + if e.peerFD >= 0 { + e.completed.Add(1) + go func() { + defer e.completed.Done() + b := make([]byte, 1) + // When sharedmem endpoint is in use the peerFD is never used for any data + // transfer and this Read should only return if the peer is shutting down. + _, err := rawfile.BlockingRead(e.peerFD, b) + if e.onClosed != nil { + e.onClosed(err) + } + }() + } + // Link endpoints are not savable. When transportation endpoints // are saved, they stop sending outgoing packets and all // incoming packets are rejected. @@ -164,18 +284,18 @@ func (e *endpoint) IsAttached() bool { // MTU implements stack.LinkEndpoint.MTU. It returns the value initialized // during construction. func (e *endpoint) MTU() uint32 { - return e.mtu - header.EthernetMinimumSize + return e.mtu - e.hdrSize } // Capabilities implements stack.LinkEndpoint.Capabilities. -func (*endpoint) Capabilities() stack.LinkEndpointCapabilities { - return 0 +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.caps } // MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the // ethernet frame header size. -func (*endpoint) MaxHeaderLength() uint16 { - return header.EthernetMinimumSize +func (e *endpoint) MaxHeaderLength() uint16 { + return uint16(e.hdrSize) } // LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local @@ -205,17 +325,15 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // WriteRawPacket implements stack.LinkEndpoint. func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) +// +checklocks:e.mu +func (e *endpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + if e.addr != "" { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + } views := pkt.Views() // Transmit the packet. - e.mu.Lock() ok := e.tx.transmit(views...) - e.mu.Unlock() - if !ok { return &tcpip.ErrWouldBlock{} } @@ -223,9 +341,37 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol return nil } +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + if err := e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil { + return err + } + e.tx.notify() + return nil +} + // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - panic("not implemented") +func (e *endpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + n := 0 + var err tcpip.Error + e.mu.Lock() + defer e.mu.Unlock() + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err = e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil { + break + } + n++ + } + // WritePackets never returns an error if it successfully transmitted at least + // one packet. + if err != nil && n == 0 { + return 0, err + } + e.tx.notify() + return n, nil } // dispatchLoop reads packets from the rx queue in a loop and dispatches them @@ -268,16 +414,42 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { Data: buffer.View(b).ToVectorisedView(), }) - hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) - if !ok { - continue + var src, dst tcpip.LinkAddress + var proto tcpip.NetworkProtocolNumber + if e.addr != "" { + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) + if !ok { + continue + } + eth := header.Ethernet(hdr) + src = eth.SourceAddress() + dst = eth.DestinationAddress() + proto = eth.Type() + } else { + // We don't get any indication of what the packet is, so try to guess + // if it's an IPv4 or IPv6 packet. + // IP version information is at the first octet, so pulling up 1 byte. + h, ok := pkt.Data().PullUp(1) + if !ok { + continue + } + switch header.IPVersion(h) { + case header.IPv4Version: + proto = header.IPv4ProtocolNumber + case header.IPv6Version: + proto = header.IPv6ProtocolNumber + default: + continue + } } - eth := header.Ethernet(hdr) // Send packet up the stack. - d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), pkt) + d.DeliverNetworkPacket(src, dst, proto, pkt) } + e.mu.Lock() + defer e.mu.Unlock() + // Clean state. e.tx.cleanup() e.rx.cleanup() diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server.go b/pkg/tcpip/link/sharedmem/sharedmem_server.go new file mode 100644 index 000000000..43c5b8c63 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/sharedmem_server.go @@ -0,0 +1,344 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package sharedmem + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type serverEndpoint struct { + // mtu (maximum transmission unit) is the maximum size of a packet. + // mtu is immutable. + mtu uint32 + + // bufferSize is the size of each individual buffer. + // bufferSize is immutable. + bufferSize uint32 + + // addr is the local address of this endpoint. + // addr is immutable + addr tcpip.LinkAddress + + // rx is the receive queue. + rx serverRx + + // stopRequested is to be accessed atomically only, and determines if the + // worker goroutines should stop. + stopRequested uint32 + + // Wait group used to indicate that all workers have stopped. + completed sync.WaitGroup + + // peerFD is an fd to the peer that can be used to detect when the peer is + // gone. + // peerFD is immutable. + peerFD int + + // caps holds the endpoint capabilities. + caps stack.LinkEndpointCapabilities + + // hdrSize is the size of the link layer header if any. + // hdrSize is immutable. + hdrSize uint32 + + // onClosed is a function to be called when the FD's peer (if any) closes its + // end of the communication pipe. + onClosed func(tcpip.Error) + + // mu protects the following fields. + mu sync.Mutex + + // tx is the transmit queue. + // +checklocks:mu + tx serverTx + + // workerStarted specifies whether the worker goroutine was started. + // +checklocks:mu + workerStarted bool +} + +// NewServerEndpoint creates a new shared-memory-based endpoint. Buffers will be +// broken up into buffers of "bufferSize" bytes. +func NewServerEndpoint(opts Options) (stack.LinkEndpoint, error) { + e := &serverEndpoint{ + mtu: opts.MTU, + bufferSize: opts.BufferSize, + addr: opts.LinkAddress, + peerFD: opts.PeerFD, + onClosed: opts.OnClosed, + } + + if err := e.tx.init(&opts.RX); err != nil { + return nil, err + } + + if err := e.rx.init(&opts.TX); err != nil { + e.tx.cleanup() + return nil, err + } + + e.caps = stack.LinkEndpointCapabilities(0) + if opts.RXChecksumOffload { + e.caps |= stack.CapabilityRXChecksumOffload + } + + if opts.TXChecksumOffload { + e.caps |= stack.CapabilityTXChecksumOffload + } + + if opts.LinkAddress != "" { + e.hdrSize = header.EthernetMinimumSize + e.caps |= stack.CapabilityResolutionRequired + } + + return e, nil +} + +// Close frees all resources associated with the endpoint. +func (e *serverEndpoint) Close() { + // Tell dispatch goroutine to stop, then write to the eventfd so that it wakes + // up in case it's sleeping. + atomic.StoreUint32(&e.stopRequested, 1) + e.rx.eventFD.Notify() + + // Cleanup the queues inline if the worker hasn't started yet; we also know it + // won't start from now on because stopRequested is set to 1. + e.mu.Lock() + defer e.mu.Unlock() + workerPresent := e.workerStarted + + if !workerPresent { + e.tx.cleanup() + e.rx.cleanup() + } +} + +// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have +// stopped after a Close() call. +func (e *serverEndpoint) Wait() { + e.completed.Wait() +} + +// Attach implements stack.LinkEndpoint.Attach. It launches the goroutine that +// reads packets from the rx queue. +func (e *serverEndpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.mu.Lock() + if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 { + e.workerStarted = true + e.completed.Add(1) + if e.peerFD >= 0 { + e.completed.Add(1) + // Spin up a goroutine to monitor for peer shutdown. + go func() { + b := make([]byte, 1) + // When sharedmem endpoint is in use the peerFD is never used for any + // data transfer and this Read should only return if the peer is + // shutting down. + _, err := rawfile.BlockingRead(e.peerFD, b) + if e.onClosed != nil { + e.onClosed(err) + } + e.completed.Done() + }() + } + // Link endpoints are not savable. When transportation endpoints are saved, + // they stop sending outgoing packets and all incoming packets are rejected. + go e.dispatchLoop(dispatcher) // S/R-SAFE: see above. + } + e.mu.Unlock() +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *serverEndpoint) IsAttached() bool { + e.mu.Lock() + defer e.mu.Unlock() + return e.workerStarted +} + +// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized +// during construction. +func (e *serverEndpoint) MTU() uint32 { + return e.mtu - e.hdrSize +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (e *serverEndpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.caps +} + +// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the +// ethernet frame header size. +func (e *serverEndpoint) MaxHeaderLength() uint16 { + return uint16(e.hdrSize) +} + +// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local +// link address. +func (e *serverEndpoint) LinkAddress() tcpip.LinkAddress { + return e.addr +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *serverEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + // Add ethernet header if needed. + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) + ethHdr := &header.EthernetFields{ + DstAddr: remote, + Type: protocol, + } + + // Preserve the src address if it's set in the route. + if local != "" { + ethHdr.SrcAddr = local + } else { + ethHdr.SrcAddr = e.addr + } + eth.Encode(ethHdr) +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket +func (e *serverEndpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error { + views := pkt.Views() + e.mu.Lock() + defer e.mu.Unlock() + ok := e.tx.transmit(views) + if !ok { + return &tcpip.ErrWouldBlock{} + } + e.tx.notify() + return nil +} + +// +checklocks:e.mu +func (e *serverEndpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + if e.addr != "" { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + } + + views := pkt.Views() + ok := e.tx.transmit(views) + if !ok { + return &tcpip.ErrWouldBlock{} + } + + return nil +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +// WritePacket implements stack.LinkEndpoint.WritePacket. +func (e *serverEndpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + // Transmit the packet. + e.mu.Lock() + defer e.mu.Unlock() + if err := e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil { + return err + } + e.tx.notify() + return nil +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *serverEndpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + n := 0 + var err tcpip.Error + e.mu.Lock() + defer e.mu.Unlock() + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err = e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil { + break + } + n++ + } + // WritePackets never returns an error if it successfully transmitted at least + // one packet. + if err != nil && n == 0 { + return 0, err + } + e.tx.notify() + return n, nil +} + +// dispatchLoop reads packets from the rx queue in a loop and dispatches them +// to the network stack. +func (e *serverEndpoint) dispatchLoop(d stack.NetworkDispatcher) { + for atomic.LoadUint32(&e.stopRequested) == 0 { + b := e.rx.receive() + if b == nil { + e.rx.waitForPackets() + continue + } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.View(b).ToVectorisedView(), + }) + var src, dst tcpip.LinkAddress + var proto tcpip.NetworkProtocolNumber + if e.addr != "" { + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) + if !ok { + continue + } + eth := header.Ethernet(hdr) + src = eth.SourceAddress() + dst = eth.DestinationAddress() + proto = eth.Type() + } else { + // We don't get any indication of what the packet is, so try to guess + // if it's an IPv4 or IPv6 packet. + // IP version information is at the first octet, so pulling up 1 byte. + h, ok := pkt.Data().PullUp(1) + if !ok { + continue + } + switch header.IPVersion(h) { + case header.IPv4Version: + proto = header.IPv4ProtocolNumber + case header.IPv6Version: + proto = header.IPv6ProtocolNumber + default: + continue + } + } + // Send packet up the stack. + d.DeliverNetworkPacket(src, dst, proto, pkt) + } + + e.mu.Lock() + defer e.mu.Unlock() + + // Clean state. + e.tx.cleanup() + e.rx.cleanup() + + e.completed.Done() +} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (e *serverEndpoint) ARPHardwareType() header.ARPHardwareType { + if e.hdrSize > 0 { + return header.ARPHardwareEther + } + return header.ARPHardwareNone +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server_test.go b/pkg/tcpip/link/sharedmem/sharedmem_server_test.go new file mode 100644 index 000000000..1bc58614e --- /dev/null +++ b/pkg/tcpip/link/sharedmem/sharedmem_server_test.go @@ -0,0 +1,220 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package sharedmem_server_test + +import ( + "fmt" + "io" + "net" + "net/http" + "syscall" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem" + "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +const ( + localLinkAddr = "\xde\xad\xbe\xef\x56\x78" + remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34" + localIPv4Address = tcpip.Address("\x0a\x00\x00\x01") + remoteIPv4Address = tcpip.Address("\x0a\x00\x00\x02") + serverPort = 10001 + + defaultMTU = 1500 + defaultBufferSize = 1500 +) + +type stackOptions struct { + ep stack.LinkEndpoint + addr tcpip.Address +} + +func newStackWithOptions(stackOpts stackOptions) (*stack.Stack, error) { + st := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocolWithOptions(ipv4.Options{ + AllowExternalLoopbackTraffic: true, + }), + ipv6.NewProtocolWithOptions(ipv6.Options{ + AllowExternalLoopbackTraffic: true, + }), + }, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, + }) + nicID := tcpip.NICID(1) + sniffEP := sniffer.New(stackOpts.ep) + opts := stack.NICOptions{Name: "eth0"} + if err := st.CreateNICWithOptions(nicID, sniffEP, opts); err != nil { + return nil, fmt.Errorf("method CreateNICWithOptions(%d, _, %v) failed: %s", nicID, opts, err) + } + + // Add Protocol Address. + protocolNum := ipv4.ProtocolNumber + routeTable := []tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}} + if len(stackOpts.addr) == 16 { + routeTable = []tcpip.Route{{Destination: header.IPv6EmptySubnet, NIC: nicID}} + protocolNum = ipv6.ProtocolNumber + } + protocolAddr := tcpip.ProtocolAddress{ + Protocol: protocolNum, + AddressWithPrefix: stackOpts.addr.WithPrefix(), + } + if err := st.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + return nil, fmt.Errorf("AddProtocolAddress(%d, %v, {}): %s", nicID, protocolAddr, err) + } + + // Setup route table. + st.SetRouteTable(routeTable) + + return st, nil +} + +func newClientStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) { + ep, err := sharedmem.New(sharedmem.Options{ + MTU: defaultMTU, + BufferSize: defaultBufferSize, + LinkAddress: localLinkAddr, + TX: qPair.TXQueueConfig(), + RX: qPair.RXQueueConfig(), + PeerFD: peerFD, + }) + if err != nil { + return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err) + } + st, err := newStackWithOptions(stackOptions{ep: ep, addr: localIPv4Address}) + if err != nil { + return nil, fmt.Errorf("failed to create client stack: %s", err) + } + return st, nil +} + +func newServerStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) { + ep, err := sharedmem.NewServerEndpoint(sharedmem.Options{ + MTU: defaultMTU, + BufferSize: defaultBufferSize, + LinkAddress: remoteLinkAddr, + TX: qPair.TXQueueConfig(), + RX: qPair.RXQueueConfig(), + PeerFD: peerFD, + }) + if err != nil { + return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err) + } + st, err := newStackWithOptions(stackOptions{ep: ep, addr: remoteIPv4Address}) + if err != nil { + return nil, fmt.Errorf("failed to create client stack: %s", err) + } + return st, nil +} + +type testContext struct { + clientStk *stack.Stack + serverStk *stack.Stack + peerFDs [2]int +} + +func newTestContext(t *testing.T) *testContext { + peerFDs, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET|syscall.SOCK_NONBLOCK, 0) + if err != nil { + t.Fatalf("failed to create peerFDs: %s", err) + } + q, err := sharedmem.NewQueuePair() + if err != nil { + t.Fatalf("failed to create sharedmem queue: %s", err) + } + clientStack, err := newClientStack(t, q, peerFDs[0]) + if err != nil { + q.Close() + unix.Close(peerFDs[0]) + unix.Close(peerFDs[1]) + t.Fatalf("failed to create client stack: %s", err) + } + serverStack, err := newServerStack(t, q, peerFDs[1]) + if err != nil { + q.Close() + unix.Close(peerFDs[0]) + unix.Close(peerFDs[1]) + clientStack.Close() + t.Fatalf("failed to create server stack: %s", err) + } + return &testContext{ + clientStk: clientStack, + serverStk: serverStack, + peerFDs: peerFDs, + } +} + +func (ctx *testContext) cleanup() { + unix.Close(ctx.peerFDs[0]) + unix.Close(ctx.peerFDs[1]) + ctx.clientStk.Close() + ctx.serverStk.Close() +} + +func TestServerRoundTrip(t *testing.T) { + ctx := newTestContext(t) + defer ctx.cleanup() + listenAddr := tcpip.FullAddress{Addr: remoteIPv4Address, Port: serverPort} + l, err := gonet.ListenTCP(ctx.serverStk, listenAddr, ipv4.ProtocolNumber) + if err != nil { + t.Fatalf("failed to start TCP Listener: %s", err) + } + defer l.Close() + var responseString = "response" + go func() { + http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(responseString)) + })) + }() + + dialFunc := func(address, protocol string) (net.Conn, error) { + return gonet.DialTCP(ctx.clientStk, listenAddr, ipv4.ProtocolNumber) + } + + httpClient := &http.Client{ + Transport: &http.Transport{ + Dial: dialFunc, + }, + } + serverURL := fmt.Sprintf("http://[%s]:%d/", net.IP(remoteIPv4Address), serverPort) + response, err := httpClient.Get(serverURL) + if err != nil { + t.Fatalf("httpClient.Get(\"/\") failed: %s", err) + } + if got, want := response.StatusCode, http.StatusOK; got != want { + t.Fatalf("unexpected status code got: %d, want: %d", got, want) + } + body, err := io.ReadAll(response.Body) + if err != nil { + t.Fatalf("io.ReadAll(response.Body) failed: %s", err) + } + response.Body.Close() + if got, want := string(body), responseString; got != want { + t.Fatalf("unexpected response got: %s, want: %s", got, want) + } +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index d6d953085..a49f5f87d 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -19,9 +19,7 @@ package sharedmem import ( "bytes" - "io/ioutil" "math/rand" - "os" "strings" "testing" "time" @@ -104,24 +102,36 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress t: t, packetCh: make(chan struct{}, 1000000), } - c.txCfg = createQueueFDs(t, queueSizes{ + c.txCfg, err = createQueueFDs(queueSizes{ dataSize: queueDataSize, txPipeSize: queuePipeSize, rxPipeSize: queuePipeSize, sharedDataSize: 4096, }) - - c.rxCfg = createQueueFDs(t, queueSizes{ + if err != nil { + t.Fatalf("createQueueFDs for tx failed: %s", err) + } + c.rxCfg, err = createQueueFDs(queueSizes{ dataSize: queueDataSize, txPipeSize: queuePipeSize, rxPipeSize: queuePipeSize, sharedDataSize: 4096, }) + if err != nil { + t.Fatalf("createQueueFDs for rx failed: %s", err) + } initQueue(t, &c.txq, &c.txCfg) initQueue(t, &c.rxq, &c.rxCfg) - ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) + ep, err := New(Options{ + MTU: mtu, + BufferSize: bufferSize, + LinkAddress: addr, + TX: c.txCfg, + RX: c.rxCfg, + PeerFD: -1, + }) if err != nil { t.Fatalf("New failed: %v", err) } @@ -150,8 +160,8 @@ func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip. func (c *testContext) cleanup() { c.ep.Close() - closeFDs(&c.txCfg) - closeFDs(&c.rxCfg) + closeFDs(c.txCfg) + closeFDs(c.rxCfg) c.txq.cleanup() c.rxq.cleanup() } @@ -191,69 +201,6 @@ func shuffle(b []int) { } } -func createFile(t *testing.T, size int64, initQueue bool) int { - tmpDir, ok := os.LookupEnv("TEST_TMPDIR") - if !ok { - tmpDir = os.Getenv("TMPDIR") - } - f, err := ioutil.TempFile(tmpDir, "sharedmem_test") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - defer f.Close() - unix.Unlink(f.Name()) - - if initQueue { - // Write the "slot-free" flag in the initial queue. - _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0) - if err != nil { - t.Fatalf("WriteAt failed: %v", err) - } - } - - fd, err := unix.Dup(int(f.Fd())) - if err != nil { - t.Fatalf("Dup failed: %v", err) - } - - if err := unix.Ftruncate(fd, size); err != nil { - unix.Close(fd) - t.Fatalf("Ftruncate failed: %v", err) - } - - return fd -} - -func closeFDs(c *QueueConfig) { - unix.Close(c.DataFD) - unix.Close(c.EventFD) - unix.Close(c.TxPipeFD) - unix.Close(c.RxPipeFD) - unix.Close(c.SharedDataFD) -} - -type queueSizes struct { - dataSize int64 - txPipeSize int64 - rxPipeSize int64 - sharedDataSize int64 -} - -func createQueueFDs(t *testing.T, s queueSizes) QueueConfig { - fd, _, err := unix.RawSyscall(unix.SYS_EVENTFD2, 0, 0, 0) - if err != 0 { - t.Fatalf("eventfd failed: %v", error(err)) - } - - return QueueConfig{ - EventFD: int(fd), - DataFD: createFile(t, s.dataSize, false), - TxPipeFD: createFile(t, s.txPipeSize, true), - RxPipeFD: createFile(t, s.rxPipeSize, true), - SharedDataFD: createFile(t, s.sharedDataSize, false), - } -} - // TestSimpleSend sends 1000 packets with random header and payload sizes, // then checks that the right payload is received on the shared memory queues. func TestSimpleSend(t *testing.T) { @@ -263,6 +210,7 @@ func TestSimpleSend(t *testing.T) { // Prepare route. var r stack.RouteInfo r.RemoteLinkAddress = remoteLinkAddr + r.LocalLinkAddress = localLinkAddr for iters := 1000; iters > 0; iters-- { func() { @@ -280,8 +228,11 @@ func TestSimpleSend(t *testing.T) { Data: data.ToVectorisedView(), }) copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf) - proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) + // Every PacketBuffer must have these set: + // See nic.writePacket. + pkt.EgressRoute = r + pkt.NetworkProtocolNumber = proto if err := c.ep.WritePacket(r, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -350,8 +301,11 @@ func TestPreserveSrcAddressInSend(t *testing.T) { // the minimum size of the ethernet header. ReserveHeaderBytes: header.EthernetMinimumSize, }) - proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) + // Every PacketBuffer must have these set: + // See nic.writePacket. + pkt.EgressRoute = r + pkt.NetworkProtocolNumber = proto if err := c.ep.WritePacket(r, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -672,7 +626,7 @@ func TestSimpleReceive(t *testing.T) { // Push completion. c.pushRxCompletion(uint32(len(contents)), bufs) c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") @@ -718,7 +672,7 @@ func TestRxBuffersReposted(t *testing.T) { // Complete the buffer. c.pushRxCompletion(buffers[i].Size, buffers[i:][:1]) c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Wait for it to be reposted. bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted")) @@ -734,7 +688,7 @@ func TestRxBuffersReposted(t *testing.T) { // Complete with two buffers. c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2]) c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Wait for them to be reposted. for j := 0; j < 2; j++ { @@ -759,7 +713,7 @@ func TestReceivePostingIsFull(t *testing.T) { first := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for first buffer to be posted")) c.pushRxCompletion(first.Size, []queue.RxBuffer{first}) c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Check that packet is received. c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet") @@ -768,7 +722,7 @@ func TestReceivePostingIsFull(t *testing.T) { second := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for second buffer to be posted")) c.pushRxCompletion(second.Size, []queue.RxBuffer{second}) c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Check that no packet is received yet, as the worker is blocked trying // to repost. @@ -781,7 +735,7 @@ func TestReceivePostingIsFull(t *testing.T) { // Flush tx queue, which will allow the first buffer to be reposted, // and the second completion to be pulled. c.rxq.tx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Check that second packet completes. c.waitForPackets(1, time.After(time.Second), "Timeout waiting for second completed packet") @@ -803,7 +757,7 @@ func TestCloseWhileWaitingToPost(t *testing.T) { bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for initial buffer to be posted")) c.pushRxCompletion(bi.Size, []queue.RxBuffer{bi}) c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Wait for packet to be indicated. c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet") diff --git a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go index f7e816a41..d974c266e 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go @@ -15,7 +15,12 @@ package sharedmem import ( + "fmt" + "reflect" "unsafe" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/memutil" ) // sharedDataPointer converts the shared data slice into a pointer so that it @@ -23,3 +28,31 @@ import ( func sharedDataPointer(sharedData []byte) *uint32 { return (*uint32)(unsafe.Pointer(&sharedData[0:4][0])) } + +// getBuffer returns a memory region mapped to the full contents of the given +// file descriptor. +func getBuffer(fd int) ([]byte, error) { + var s unix.Stat_t + if err := unix.Fstat(fd, &s); err != nil { + return nil, err + } + + // Check that size doesn't overflow an int. + if s.Size > int64(^uint(0)>>1) { + return nil, unix.EDOM + } + + addr, err := memutil.MapFile(0 /* addr */, uintptr(s.Size), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_FILE, uintptr(fd), 0 /*offset*/) + if err != nil { + return nil, fmt.Errorf("failed to map memory for buffer fd: %d, error: %s", fd, err) + } + + // Use unsafe to conver addr into a []byte. + var b []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + hdr.Data = addr + hdr.Len = int(s.Size) + hdr.Cap = int(s.Size) + + return b, nil +} diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go index e3210051f..d6c61afee 100644 --- a/pkg/tcpip/link/sharedmem/tx.go +++ b/pkg/tcpip/link/sharedmem/tx.go @@ -18,6 +18,7 @@ import ( "math" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/eventfd" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" ) @@ -28,10 +29,12 @@ const ( // tx holds all state associated with a tx queue. type tx struct { - data []byte - q queue.Tx - ids idManager - bufs bufferManager + data []byte + q queue.Tx + ids idManager + bufs bufferManager + eventFD eventfd.Eventfd + sharedDataFD int } // init initializes all state needed by the tx queue based on the information @@ -64,7 +67,8 @@ func (t *tx) init(mtu uint32, c *QueueConfig) error { t.ids.init() t.bufs.init(0, len(data), int(mtu)) t.data = data - + t.eventFD = c.EventFD + t.sharedDataFD = c.SharedDataFD return nil } @@ -142,20 +146,10 @@ func (t *tx) transmit(bufs ...buffer.View) bool { return true } -// getBuffer returns a memory region mapped to the full contents of the given -// file descriptor. -func getBuffer(fd int) ([]byte, error) { - var s unix.Stat_t - if err := unix.Fstat(fd, &s); err != nil { - return nil, err - } - - // Check that size doesn't overflow an int. - if s.Size > int64(^uint(0)>>1) { - return nil, unix.EDOM - } - - return unix.Mmap(fd, 0, int(s.Size), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_FILE) +// notify writes to the tx.eventFD to indicate to the peer that there is data to +// be read. +func (t *tx) notify() { + t.eventFD.Notify() } // idDescriptor is used by idManager to either point to a tx buffer (in case diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 6515c31e5..e08243547 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -272,7 +272,6 @@ type protocol struct { func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } func (p *protocol) MinimumPacketSize() int { return header.ARPSize } -func (p *protocol) DefaultPrefixLen() int { return 0 } func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) { return "", "" diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 5fcbfeaa2..061cc35ae 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -153,8 +153,12 @@ func makeTestContext(t *testing.T, eventDepth int, packetDepth int) testContext t.Fatalf("CreateNIC failed: %s", err) } - if err := tc.s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %s", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: stackAddr.WithPrefix(), + } + if err := tc.s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } tc.s.SetRouteTable([]tcpip.Route{{ @@ -569,8 +573,12 @@ func TestLinkAddressRequest(t *testing.T) { } if len(test.nicAddr) != 0 { - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: test.nicAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 2179302d3..87f650661 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -233,7 +233,13 @@ func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) s.CreateNIC(nicID, loopback.New()) - s.AddAddress(nicID, ipv4.ProtocolNumber, local) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: local.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + return nil, err + } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, Gateway: ipv4Gateway, @@ -249,7 +255,13 @@ func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) s.CreateNIC(nicID, loopback.New()) - s.AddAddress(nicID, ipv6.ProtocolNumber, local) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: local.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + return nil, err + } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv6EmptySubnet, Gateway: ipv6Gateway, @@ -272,13 +284,13 @@ func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *c } v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix} - if err := s.AddProtocolAddress(nicID, v4Addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err) + if err := s.AddProtocolAddress(nicID, v4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v4Addr, err) } v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix} - if err := s.AddProtocolAddress(nicID, v6Addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err) + if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v6Addr, err) } return s, e @@ -713,8 +725,8 @@ func TestReceive(t *testing.T) { if !ok { t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum) } - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", test.epAddr, err) } else { ep.DecRef() } @@ -885,8 +897,8 @@ func TestIPv4ReceiveControl(t *testing.T) { t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") } addr := localIPv4Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } @@ -971,8 +983,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") } addr := localIPv4Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } @@ -1237,8 +1249,8 @@ func TestIPv6ReceiveControl(t *testing.T) { t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint") } addr := localIPv6Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } @@ -1304,7 +1316,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name string protoFactory stack.NetworkProtocolFactory protoNum tcpip.NetworkProtocolNumber - nicAddr tcpip.Address + nicAddr tcpip.AddressWithPrefix remoteAddr tcpip.Address pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) @@ -1314,7 +1326,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv4MinimumSize + len(data) @@ -1355,7 +1367,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 with IHL too small", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv4MinimumSize + len(data) @@ -1379,7 +1391,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 too small", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) @@ -1397,7 +1409,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 minimum size", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) @@ -1433,7 +1445,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 with options", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) @@ -1478,7 +1490,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 with options and data across views", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length())) @@ -1519,7 +1531,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv6", protoFactory: ipv6.NewProtocol, protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, + nicAddr: localIPv6AddrWithPrefix, remoteAddr: remoteIPv6Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv6MinimumSize + len(data) @@ -1559,7 +1571,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv6 with extension header", protoFactory: ipv6.NewProtocol, protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, + nicAddr: localIPv6AddrWithPrefix, remoteAddr: remoteIPv6Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) @@ -1604,7 +1616,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv6 minimum size", protoFactory: ipv6.NewProtocol, protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, + nicAddr: localIPv6AddrWithPrefix, remoteAddr: remoteIPv6Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) @@ -1639,7 +1651,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv6 too small", protoFactory: ipv6.NewProtocol, protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, + nicAddr: localIPv6AddrWithPrefix, remoteAddr: remoteIPv6Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) @@ -1663,11 +1675,11 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { }{ { name: "unspecified source", - srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))), + srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr.Address))), }, { name: "random source", - srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))), + srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr.Address))), }, } @@ -1680,15 +1692,19 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: test.protoNum, + AddressWithPrefix: test.nicAddr, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}}) - r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */) + r, err := s.FindRoute(nicID, test.nicAddr.Address, test.remoteAddr, test.protoNum, false /* multicastLoop */) if err != nil { - t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err) + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr.Address, test.protoNum, err) } defer r.Release() @@ -2072,8 +2088,12 @@ func TestSetNICIDBeforeDeliveringToRawEndpoint(t *testing.T) { if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddressWithPrefix(nicID, test.proto, test.addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, test.proto, test.addr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: test.proto, + AddressWithPrefix: test.addr, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{ diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 2aa38eb98..3eff0bbd8 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -167,23 +167,22 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet p := hdr.TransportProtocol() dstAddr := hdr.DestinationAddress() // Skip the ip header, then deliver the error. - pkt.Data().DeleteFront(hlen) + if _, ok := pkt.Data().Consume(hlen); !ok { + panic(fmt.Sprintf("could not consume the IP header of %d bytes", hlen)) + } e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, errInfo, pkt) } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { received := e.stats.icmp.packetsReceived - // ICMP packets don't have their TransportHeader fields set. See - // icmp/protocol.go:protocol.Parse for a full explanation. - v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize) - if !ok { + h := header.ICMPv4(pkt.TransportHeader().View()) + if len(h) < header.ICMPv4MinimumSize { received.invalid.Increment() return } - h := header.ICMPv4(v) // Only do in-stack processing if the checksum is correct. - if pkt.Data().AsRange().Checksum() != 0xffff { + if header.Checksum(h, pkt.Data().AsRange().Checksum()) != 0xffff { received.invalid.Increment() // It's possible that a raw socket expects to receive this regardless // of checksum errors. If it's an echo request we know it's safe because @@ -240,20 +239,15 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4Echo: received.echoRequest.Increment() - sent := e.stats.icmp.packetsSent - if !e.protocol.stack.AllowICMPMessage() { - sent.rateLimited.Increment() - return - } - // DeliverTransportPacket will take ownership of pkt so don't use it beyond // this point. Make a deep copy of the data before pkt gets sent as we will - // be modifying fields. + // be modifying fields. Both the ICMP header (with its type modified to + // EchoReply) and payload are reused in the reply packet. // // TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no // waiting endpoints. Consider moving responsibility for doing the copy to // DeliverTransportPacket so that is is only done when needed. - replyData := pkt.Data().AsRange().ToOwnedView() + replyData := stack.PayloadSince(pkt.TransportHeader()) ipHdr := header.IPv4(pkt.NetworkHeader().View()) localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast @@ -281,6 +275,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { } defer r.Release() + sent := e.stats.icmp.packetsSent + if !e.protocol.allowICMPReply(header.ICMPv4EchoReply, header.ICMPv4UnusedCode) { + sent.rateLimited.Increment() + return + } + // TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the // header information, we may have to change this code to handle the // ICMP header no longer being in the data buffer. @@ -331,6 +331,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4EchoReply: received.echoReply.Increment() + // ICMP sockets expect the ICMP header to be present, so we don't consume + // the ICMP header. e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: @@ -338,7 +340,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { mtu := h.MTU() code := h.Code() - pkt.Data().DeleteFront(header.ICMPv4MinimumSize) switch code { case header.ICMPv4HostUnreachable: e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt) @@ -562,31 +563,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip return &tcpip.ErrNotConnected{} } - sent := netEP.stats.icmp.packetsSent - - if !p.stack.AllowICMPMessage() { - sent.rateLimited.Increment() - return nil - } - transportHeader := pkt.TransportHeader().View() // Don't respond to icmp error packets. if origIPHdr.Protocol() == uint8(header.ICMPv4ProtocolNumber) { - // TODO(gvisor.dev/issue/3810): - // Unfortunately the current stack pretty much always has ICMPv4 headers - // in the Data section of the packet but there is no guarantee that is the - // case. If this is the case grab the header to make it like all other - // packet types. When this is cleaned up the Consume should be removed. - if transportHeader.IsEmpty() { - var ok bool - transportHeader, ok = pkt.TransportHeader().Consume(header.ICMPv4MinimumSize) - if !ok { - return nil - } - } else if transportHeader.Size() < header.ICMPv4MinimumSize { - return nil - } // We need to decide to explicitly name the packets we can respond to or // the ones we can not respond to. The decision is somewhat arbitrary and // if problems arise this could be reversed. It was judged less of a breach @@ -606,6 +586,35 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip } } + sent := netEP.stats.icmp.packetsSent + icmpType, icmpCode, counter, pointer := func() (header.ICMPv4Type, header.ICMPv4Code, tcpip.MultiCounterStat, byte) { + switch reason := reason.(type) { + case *icmpReasonPortUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4PortUnreachable, sent.dstUnreachable, 0 + case *icmpReasonProtoUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4ProtoUnreachable, sent.dstUnreachable, 0 + case *icmpReasonNetworkUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4NetUnreachable, sent.dstUnreachable, 0 + case *icmpReasonHostUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, sent.dstUnreachable, 0 + case *icmpReasonFragmentationNeeded: + return header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, sent.dstUnreachable, 0 + case *icmpReasonTTLExceeded: + return header.ICMPv4TimeExceeded, header.ICMPv4TTLExceeded, sent.timeExceeded, 0 + case *icmpReasonReassemblyTimeout: + return header.ICMPv4TimeExceeded, header.ICMPv4ReassemblyTimeout, sent.timeExceeded, 0 + case *icmpReasonParamProblem: + return header.ICMPv4ParamProblem, header.ICMPv4UnusedCode, sent.paramProblem, reason.pointer + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } + }() + + if !p.allowICMPReply(icmpType, icmpCode) { + sent.rateLimited.Increment() + return nil + } + // Now work out how much of the triggering packet we should return. // As per RFC 1812 Section 4.3.2.3 // @@ -658,44 +667,9 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) - var counter tcpip.MultiCounterStat - switch reason := reason.(type) { - case *icmpReasonPortUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4PortUnreachable) - counter = sent.dstUnreachable - case *icmpReasonProtoUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) - counter = sent.dstUnreachable - case *icmpReasonNetworkUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4NetUnreachable) - counter = sent.dstUnreachable - case *icmpReasonHostUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4HostUnreachable) - counter = sent.dstUnreachable - case *icmpReasonFragmentationNeeded: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4FragmentationNeeded) - counter = sent.dstUnreachable - case *icmpReasonTTLExceeded: - icmpHdr.SetType(header.ICMPv4TimeExceeded) - icmpHdr.SetCode(header.ICMPv4TTLExceeded) - counter = sent.timeExceeded - case *icmpReasonReassemblyTimeout: - icmpHdr.SetType(header.ICMPv4TimeExceeded) - icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout) - counter = sent.timeExceeded - case *icmpReasonParamProblem: - icmpHdr.SetType(header.ICMPv4ParamProblem) - icmpHdr.SetCode(header.ICMPv4UnusedCode) - icmpHdr.SetPointer(reason.pointer) - counter = sent.paramProblem - default: - panic(fmt.Sprintf("unsupported ICMP type %T", reason)) - } + icmpHdr.SetCode(icmpCode) + icmpHdr.SetType(icmpType) + icmpHdr.SetPointer(pointer) icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().AsRange().Checksum())) if err := route.WritePacket( diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go index 4bd6f462e..c6576fcbc 100644 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -120,9 +120,12 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma // cycles. func TestIGMPV1Present(t *testing.T) { e, s, clock := createStack(t, true) - addr := tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength} - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength}, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { @@ -215,8 +218,15 @@ func TestSendQueuedIGMPReports(t *testing.T) { // The initial set of IGMP reports that were queued should be sent once an // address is assigned. - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: stackAddr, + PrefixLen: defaultPrefixLength, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if got := reportStat.Value(); got != 1 { t.Errorf("got reportStat.Value() = %d, want = 1", got) @@ -350,8 +360,12 @@ func TestIGMPPacketValidation(t *testing.T) { t.Run(test.name, func(t *testing.T) { e, s, _ := createStack(t, true) for _, address := range test.stackAddresses { - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, address); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, address, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: address, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } stats := s.Stats() diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e2472c851..d1d509702 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -167,6 +167,13 @@ func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint { return nil } +func (p *protocol) getEndpointForNIC(id tcpip.NICID) (*endpoint, bool) { + p.mu.RLock() + defer p.mu.RUnlock() + ep, ok := p.mu.eps[id] + return ep, ok +} + func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { p.mu.Lock() defer p.mu.Unlock() @@ -240,7 +247,7 @@ func (e *endpoint) Enable() tcpip.Error { } // Create an endpoint to receive broadcast packets on this interface. - ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint}) if err != nil { return err } @@ -419,7 +426,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -432,7 +439,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // We should do this for every packet, rather than only NATted packets, but // removing this check short circuits broadcasts before they are sent out to // other hosts. - if pkt.NatDone { + if pkt.DNATDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { // Since we rewrote the packet but it is being routed back to us, we @@ -459,7 +466,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn // Postrouting NAT can only change the source address, and does not alter the // route or outgoing interface of the packet. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesPostroutingDropped.Increment() return nil @@ -542,7 +549,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName) + outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName) stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped))) for pkt := range outputDropped { pkts.Remove(pkt) @@ -569,7 +576,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // We ignore the list of NAT-ed packets here because Postrouting NAT can only // change the source address, and does not alter the route or outgoing // interface of the packet. - postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName) + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, e, outNicName) stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) for pkt := range postroutingDropped { pkts.Remove(pkt) @@ -710,7 +717,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(ep.nic.ID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -737,7 +744,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(r.NICID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -746,7 +753,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { // We need to do a deep copy of the IP packet because // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do // not own it. - newHdr := header.IPv4(stack.PayloadSince(pkt.NetworkHeader())) + newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength())) + newHdr := header.IPv4(newPkt.NetworkHeader().View()) // As per RFC 791 page 30, Time to Live, // @@ -755,12 +763,19 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { // Even if no local information is available on the time actually // spent, the field must be decremented by 1. newHdr.SetTTL(ttl - 1) + // We perform a full checksum as we may have updated options above. The IP + // header is relatively small so this is not expected to be an expensive + // operation. + newHdr.SetChecksum(0) + newHdr.SetChecksum(^newHdr.CalculateChecksum()) + + forwardToEp, ok := e.protocol.getEndpointForNIC(r.NICID()) + if !ok { + // The interface was removed after we obtained the route. + return &ip.ErrOther{Err: &tcpip.ErrUnknownDevice{}} + } - switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: buffer.View(newHdr).ToVectorisedView(), - IsForwardedPacket: true, - })); err.(type) { + switch err := forwardToEp.writePacket(r, newPkt, true /* headerIncluded */); err.(type) { case nil: return nil case *tcpip.ErrMessageTooLong: @@ -826,7 +841,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -925,7 +940,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok { // iptables is telling us to drop the packet. stats.ip.IPTablesInputDropped.Increment() return @@ -969,7 +984,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, } proto := h.Protocol() - resPkt, _, ready, err := e.protocol.fragmentation.Process( + resPkt, transProtoNum, ready, err := e.protocol.fragmentation.Process( // As per RFC 791 section 2.3, the identification value is unique // for a source-destination pair and protocol. fragmentation.FragmentID{ @@ -1000,6 +1015,8 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, h.SetTotalLength(uint16(pkt.Data().Size() + len(h))) h.SetFlagsFragmentOffset(0, 0) + e.protocol.parseTransport(pkt, tcpip.TransportProtocolNumber(transProtoNum)) + // Now that the packet is reassembled, it can be sent to raw sockets. e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt) } @@ -1075,11 +1092,11 @@ func (e *endpoint) Close() { } // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. -func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties) if err == nil { e.mu.igmp.sendQueuedReports() } @@ -1200,6 +1217,9 @@ type protocol struct { // eps is keyed by NICID to allow protocol methods to retrieve an endpoint // when handling a packet, by looking at which NIC handled the packet. eps map[tcpip.NICID]*endpoint + + // ICMP types for which the stack's global rate limiting must apply. + icmpRateLimitedTypes map[header.ICMPv4Type]struct{} } // defaultTTL is the current default TTL for the protocol. Only the @@ -1226,11 +1246,6 @@ func (p *protocol) MinimumPacketSize() int { return header.IPv4MinimumSize } -// DefaultPrefixLen returns the IPv4 default prefix length. -func (p *protocol) DefaultPrefixLen() int { - return header.IPv4AddressSize * 8 -} - // ParseAddresses implements stack.NetworkProtocol. func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv4(v) @@ -1297,19 +1312,29 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool) } if hasTransportHdr { - switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { - case stack.ParsedOK: - case stack.UnknownTransportProtocol, stack.TransportLayerParseError: - // The transport layer will handle unknown protocols and transport layer - // parsing errors. - default: - panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) - } + p.parseTransport(pkt, transProtoNum) } return h, true } +func (p *protocol) parseTransport(pkt *stack.PacketBuffer, transProtoNum tcpip.TransportProtocolNumber) { + if transProtoNum == header.ICMPv4ProtocolNumber { + // The transport layer will handle transport layer parsing errors. + _ = parse.ICMPv4(pkt) + return + } + + switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { + case stack.ParsedOK: + case stack.UnknownTransportProtocol, stack.TransportLayerParseError: + // The transport layer will handle unknown protocols and transport layer + // parsing errors. + default: + panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) + } +} + // Parse implements stack.NetworkProtocol. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { if ok := parse.IPv4(pkt); !ok { @@ -1320,6 +1345,23 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true } +// allowICMPReply reports whether an ICMP reply with provided type and code may +// be sent following the rate mask options and global ICMP rate limiter. +func (p *protocol) allowICMPReply(icmpType header.ICMPv4Type, code header.ICMPv4Code) bool { + // Mimic linux and never rate limit for PMTU discovery. + // https://github.com/torvalds/linux/blob/9e9fb7655ed585da8f468e29221f0ba194a5f613/net/ipv4/icmp.c#L288 + if icmpType == header.ICMPv4DstUnreachable && code == header.ICMPv4FragmentationNeeded { + return true + } + p.mu.RLock() + defer p.mu.RUnlock() + + if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok { + return p.stack.AllowICMPMessage() + } + return true +} + // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload mtu. func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) { @@ -1399,6 +1441,14 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { } p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) p.mu.eps = make(map[tcpip.NICID]*endpoint) + // Set ICMP rate limiting to Linux defaults. + // See https://man7.org/linux/man-pages/man7/icmp.7.html. + p.mu.icmpRateLimitedTypes = map[header.ICMPv4Type]struct{}{ + header.ICMPv4DstUnreachable: struct{}{}, + header.ICMPv4SrcQuench: struct{}{}, + header.ICMPv4TimeExceeded: struct{}{}, + header.ICMPv4ParamProblem: struct{}{}, + } return p } } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 73407be67..ef91245d7 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -101,8 +101,12 @@ func TestExcludeBroadcast(t *testing.T) { defer ep.Close() // Add a valid primary endpoint address, now we can connect. - if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.Address("\x0a\x00\x00\x02").WithPrefix(), + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } if err := ep.Connect(randomAddr); err != nil { t.Errorf("Connect failed: %v", err) @@ -356,8 +360,8 @@ func TestForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) } incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr} - if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv4ProtoAddr, err) + if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv4ProtoAddr, err) } expectedEmittedPacketCount := 1 @@ -369,8 +373,8 @@ func TestForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) } outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr} - if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv4ProtoAddr, err) + if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv4ProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -1184,8 +1188,8 @@ func TestIPv4Sanity(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} - if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err) } // Default routes for IPv4 so ICMP can find a route to the remote @@ -1745,8 +1749,8 @@ func TestInvalidFragments(t *testing.T) { const ( nicID = 1 linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - addr1 = "\x0a\x00\x00\x01" - addr2 = "\x0a\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x02") tos = 0 ident = 1 ttl = 48 @@ -2012,8 +2016,12 @@ func TestInvalidFragments(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } for _, f := range test.fragments { @@ -2061,8 +2069,8 @@ func TestFragmentReassemblyTimeout(t *testing.T) { const ( nicID = 1 linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - addr1 = "\x0a\x00\x00\x01" - addr2 = "\x0a\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x02") tos = 0 ident = 1 ttl = 48 @@ -2237,8 +2245,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, @@ -2308,9 +2320,9 @@ func TestReceiveFragments(t *testing.T) { const ( nicID = 1 - addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1 - addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2 - addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3 + addr1 = tcpip.Address("\x0c\xa8\x00\x01") // 192.168.0.1 + addr2 = tcpip.Address("\x0c\xa8\x00\x02") // 192.168.0.2 + addr3 = tcpip.Address("\x0c\xa8\x00\x03") // 192.168.0.3 ) // Build and return a UDP header containing payload. @@ -2703,8 +2715,12 @@ func TestReceiveFragments(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } wq := waiter.Queue{} @@ -2985,11 +3001,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { t.Fatalf("CreateNIC(1, _) failed: %s", err) } const ( - src = "\x10\x00\x00\x01" - dst = "\x10\x00\x00\x02" + src = tcpip.Address("\x10\x00\x00\x01") + dst = tcpip.Address("\x10\x00\x00\x02") ) - if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: src.WithPrefix(), + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { mask := tcpip.AddressMask(header.IPv4Broadcast) @@ -3161,8 +3181,8 @@ func TestPacketQueuing(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err) + if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err) } s.SetRouteTable([]tcpip.Route{ @@ -3285,8 +3305,12 @@ func TestCloseLocking(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: src.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{ @@ -3349,3 +3373,139 @@ func TestCloseLocking(t *testing.T) { } }() } + +func TestIcmpRateLimit(t *testing.T) { + var ( + host1IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), + PrefixLen: 24, + }, + } + host2IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), + PrefixLen: 24, + }, + } + ) + const icmpBurst = 5 + e := channel.New(1, defaultMTU, tcpip.LinkAddress("")) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: faketime.NewManualClock(), + }) + s.SetICMPBurst(icmpBurst) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err) + } + s.SetRouteTable([]tcpip.Route{ + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: nicID, + }, + }) + tests := []struct { + name string + createPacket func() buffer.View + check func(*testing.T, *channel.Endpoint, int) + }{ + { + name: "echo", + createPacket: func() buffer.View { + totalLength := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLength) + icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpH.SetIdent(1) + icmpH.SetSequence(1) + icmpH.SetType(header.ICMPv4Echo) + icmpH.SetCode(header.ICMPv4UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(^header.Checksum(icmpH, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLength), + Protocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 1, + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + return hdr.View() + }, + check: func(t *testing.T, e *channel.Endpoint, round int) { + p, ok := e.Read() + if !ok { + t.Fatalf("expected echo response, no packet read in endpoint in round %d", round) + } + if got, want := p.Proto, header.IPv4ProtocolNumber; got != want { + t.Errorf("got p.Proto = %d, want = %d", got, want) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply), + )) + }, + }, + { + name: "dst unreachable", + createPacket: func() buffer.View { + totalLength := header.IPv4MinimumSize + header.UDPMinimumSize + hdr := buffer.NewPrependable(totalLength) + udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + udpH.Encode(&header.UDPFields{ + SrcPort: 100, + DstPort: 101, + Length: header.UDPMinimumSize, + }) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLength), + Protocol: uint8(header.UDPProtocolNumber), + TTL: 1, + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + return hdr.View() + }, + check: func(t *testing.T, e *channel.Endpoint, round int) { + p, ok := e.Read() + if round >= icmpBurst { + if ok { + t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round) + } + return + } + if !ok { + t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + )) + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + for round := 0; round < icmpBurst+1; round++ { + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: testCase.createPacket().ToVectorisedView(), + })) + testCase.check(t, e, round) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index f99cbf8f3..f814926a3 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -51,6 +51,7 @@ go_test( "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", + "@org_golang_x_time//rate:go_default_library", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 94caaae6c..adfc8d8da 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -187,7 +187,9 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // Skip the IP header, then handle the fragmentation header if there // is one. - pkt.Data().DeleteFront(header.IPv6MinimumSize) + if _, ok := pkt.Data().Consume(header.IPv6MinimumSize); !ok { + panic("could not consume IPv6MinimumSize bytes") + } if p == header.IPv6FragmentHeader { f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize) if !ok { @@ -203,7 +205,9 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // Skip fragmentation header and find out the actual protocol // number. - pkt.Data().DeleteFront(header.IPv6FragmentHeaderSize) + if _, ok := pkt.Data().Consume(header.IPv6FragmentHeaderSize); !ok { + panic("could not consume IPv6FragmentHeaderSize bytes") + } } e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, transErr, pkt) @@ -270,7 +274,7 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP if routerAlert == nil || routerAlert.Value != header.IPv6RouterAlertMLD { return false } - if pkt.Data().Size() < header.ICMPv6HeaderSize+header.MLDMinimumSize { + if pkt.TransportHeader().View().Size() < header.ICMPv6HeaderSize+header.MLDMinimumSize { return false } if iph.HopLimit() != header.MLDHopLimit { @@ -285,20 +289,17 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, routerAlert *header.IPv6RouterAlertOption) { sent := e.stats.icmp.packetsSent received := e.stats.icmp.packetsReceived - // ICMP packets don't have their TransportHeader fields set. See - // icmp/protocol.go:protocol.Parse for a full explanation. - v, ok := pkt.Data().PullUp(header.ICMPv6HeaderSize) - if !ok { + h := header.ICMPv6(pkt.TransportHeader().View()) + if len(h) < header.ICMPv6MinimumSize { received.invalid.Increment() return } - h := header.ICMPv6(v) iph := header.IPv6(pkt.NetworkHeader().View()) srcAddr := iph.SourceAddress() dstAddr := iph.DestinationAddress() // Validate ICMPv6 checksum before processing the packet. - payload := pkt.Data().AsRange().SubRange(len(h)) + payload := pkt.Data().AsRange() if got, want := h.Checksum(), header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: h, Src: srcAddr, @@ -325,28 +326,15 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r switch icmpType := h.Type(); icmpType { case header.ICMPv6PacketTooBig: received.packetTooBig.Increment() - hdr, ok := pkt.Data().PullUp(header.ICMPv6PacketTooBigMinimumSize) - if !ok { - received.invalid.Increment() - return - } - networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize) + networkMTU, err := calculateNetworkMTU(h.MTU(), header.IPv6MinimumSize) if err != nil { networkMTU = 0 } - pkt.Data().DeleteFront(header.ICMPv6PacketTooBigMinimumSize) e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt) case header.ICMPv6DstUnreachable: received.dstUnreachable.Increment() - hdr, ok := pkt.Data().PullUp(header.ICMPv6DstUnreachableMinimumSize) - if !ok { - received.invalid.Increment() - return - } - code := header.ICMPv6(hdr).Code() - pkt.Data().DeleteFront(header.ICMPv6DstUnreachableMinimumSize) - switch code { + switch h.Code() { case header.ICMPv6NetworkUnreachable: e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt) case header.ICMPv6PortUnreachable: @@ -354,16 +342,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r } case header.ICMPv6NeighborSolicit: received.neighborSolicit.Increment() - if !isNDPValid() || pkt.Data().Size() < header.ICMPv6NeighborSolicitMinimumSize { + if !isNDPValid() || len(h) < header.ICMPv6NeighborSolicitMinimumSize { received.invalid.Increment() return } - // The remainder of payload must be only the neighbor solicitation, so - // payload.AsView() always returns the solicitation. Per RFC 6980 section 5, - // NDP messages cannot be fragmented. Also note that in the common case NDP - // datagrams are very small and AsView() will not incur allocations. - ns := header.NDPNeighborSolicit(payload.AsView()) + ns := header.NDPNeighborSolicit(h.MessageBody()) targetAddr := ns.TargetAddress() // As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast @@ -576,16 +560,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r case header.ICMPv6NeighborAdvert: received.neighborAdvert.Increment() - if !isNDPValid() || pkt.Data().Size() < header.ICMPv6NeighborAdvertMinimumSize { + if !isNDPValid() || len(h) < header.ICMPv6NeighborAdvertMinimumSize { received.invalid.Increment() return } - // The remainder of payload must be only the neighbor advertisement, so - // payload.AsView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and AsView() will not incur allocations. - na := header.NDPNeighborAdvert(payload.AsView()) + na := header.NDPNeighborAdvert(h.MessageBody()) it, err := na.Options().Iter(false /* check */) if err != nil { @@ -672,12 +652,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r case header.ICMPv6EchoRequest: received.echoRequest.Increment() - icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize) - if !ok { - received.invalid.Increment() - return - } - // As per RFC 4291 section 2.7, multicast addresses must not be used as // source addresses in IPv6 packets. localAddr := dstAddr @@ -692,13 +666,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r } defer r.Release() + if !e.protocol.allowICMPReply(header.ICMPv6EchoReply) { + sent.rateLimited.Increment() + return + } + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, Data: pkt.Data().ExtractVV(), }) icmp := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber - copy(icmp, icmpHdr) + copy(icmp, h) icmp.SetType(header.ICMPv6EchoReply) dataRange := replyPkt.Data().AsRange() icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ @@ -720,7 +699,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r case header.ICMPv6EchoReply: received.echoReply.Increment() - if pkt.Data().Size() < header.ICMPv6EchoMinimumSize { + if len(h) < header.ICMPv6EchoMinimumSize { received.invalid.Increment() return } @@ -740,7 +719,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // Is the NDP payload of sufficient size to hold a Router Solictation? - if !isNDPValid() || pkt.Data().Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize { + if !isNDPValid() || len(h)-header.ICMPv6HeaderSize < header.NDPRSMinimumSize { received.invalid.Increment() return } @@ -750,9 +729,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } - // Note that in the common case NDP datagrams are very small and AsView() - // will not incur allocations. - rs := header.NDPRouterSolicit(payload.AsView()) + rs := header.NDPRouterSolicit(h.MessageBody()) it, err := rs.Options().Iter(false /* check */) if err != nil { // Options are not valid as per the wire format, silently drop the packet. @@ -796,7 +773,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // Is the NDP payload of sufficient size to hold a Router Advertisement? - if !isNDPValid() || pkt.Data().Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize { + if !isNDPValid() || len(h)-header.ICMPv6HeaderSize < header.NDPRAMinimumSize { received.invalid.Increment() return } @@ -810,9 +787,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } - // Note that in the common case NDP datagrams are very small and AsView() - // will not incur allocations. - ra := header.NDPRouterAdvert(payload.AsView()) + ra := header.NDPRouterAdvert(h.MessageBody()) it, err := ra.Options().Iter(false /* check */) if err != nil { // Options are not valid as per the wire format, silently drop the packet. @@ -890,11 +865,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r switch icmpType { case header.ICMPv6MulticastListenerQuery: e.mu.Lock() - e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.AsView())) + e.mu.mld.handleMulticastListenerQuery(header.MLD(h.MessageBody())) e.mu.Unlock() case header.ICMPv6MulticastListenerReport: e.mu.Lock() - e.mu.mld.handleMulticastListenerReport(header.MLD(payload.AsView())) + e.mu.mld.handleMulticastListenerReport(header.MLD(h.MessageBody())) e.mu.Unlock() case header.ICMPv6MulticastListenerDone: default: @@ -1174,28 +1149,37 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip return &tcpip.ErrNotConnected{} } - sent := netEP.stats.icmp.packetsSent - - if !p.stack.AllowICMPMessage() { - sent.rateLimited.Increment() - return nil - } - if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber { - // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored. - // Unfortunately at this time ICMP Packets do not have a transport - // header separated out. It is in the Data part so we need to - // separate it out now. We will just pretend it is a minimal length - // ICMP packet as we don't really care if any later bits of a - // larger ICMP packet are in the header view or in the Data view. - transport, ok := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize) - if !ok { + if typ := header.ICMPv6(pkt.TransportHeader().View()).Type(); typ.IsErrorType() || typ == header.ICMPv6RedirectMsg { return nil } - typ := header.ICMPv6(transport).Type() - if typ.IsErrorType() || typ == header.ICMPv6RedirectMsg { - return nil + } + + sent := netEP.stats.icmp.packetsSent + icmpType, icmpCode, counter, typeSpecific := func() (header.ICMPv6Type, header.ICMPv6Code, tcpip.MultiCounterStat, uint32) { + switch reason := reason.(type) { + case *icmpReasonParameterProblem: + return header.ICMPv6ParamProblem, reason.code, sent.paramProblem, reason.pointer + case *icmpReasonPortUnreachable: + return header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, sent.dstUnreachable, 0 + case *icmpReasonNetUnreachable: + return header.ICMPv6DstUnreachable, header.ICMPv6NetworkUnreachable, sent.dstUnreachable, 0 + case *icmpReasonHostUnreachable: + return header.ICMPv6DstUnreachable, header.ICMPv6AddressUnreachable, sent.dstUnreachable, 0 + case *icmpReasonPacketTooBig: + return header.ICMPv6PacketTooBig, header.ICMPv6UnusedCode, sent.packetTooBig, 0 + case *icmpReasonHopLimitExceeded: + return header.ICMPv6TimeExceeded, header.ICMPv6HopLimitExceeded, sent.timeExceeded, 0 + case *icmpReasonReassemblyTimeout: + return header.ICMPv6TimeExceeded, header.ICMPv6ReassemblyTimeout, sent.timeExceeded, 0 + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } + }() + + if !p.allowICMPReply(icmpType) { + sent.rateLimited.Increment() + return nil } network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() @@ -1232,40 +1216,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) - var counter tcpip.MultiCounterStat - switch reason := reason.(type) { - case *icmpReasonParameterProblem: - icmpHdr.SetType(header.ICMPv6ParamProblem) - icmpHdr.SetCode(reason.code) - icmpHdr.SetTypeSpecific(reason.pointer) - counter = sent.paramProblem - case *icmpReasonPortUnreachable: - icmpHdr.SetType(header.ICMPv6DstUnreachable) - icmpHdr.SetCode(header.ICMPv6PortUnreachable) - counter = sent.dstUnreachable - case *icmpReasonNetUnreachable: - icmpHdr.SetType(header.ICMPv6DstUnreachable) - icmpHdr.SetCode(header.ICMPv6NetworkUnreachable) - counter = sent.dstUnreachable - case *icmpReasonHostUnreachable: - icmpHdr.SetType(header.ICMPv6DstUnreachable) - icmpHdr.SetCode(header.ICMPv6AddressUnreachable) - counter = sent.dstUnreachable - case *icmpReasonPacketTooBig: - icmpHdr.SetType(header.ICMPv6PacketTooBig) - icmpHdr.SetCode(header.ICMPv6UnusedCode) - counter = sent.packetTooBig - case *icmpReasonHopLimitExceeded: - icmpHdr.SetType(header.ICMPv6TimeExceeded) - icmpHdr.SetCode(header.ICMPv6HopLimitExceeded) - counter = sent.timeExceeded - case *icmpReasonReassemblyTimeout: - icmpHdr.SetType(header.ICMPv6TimeExceeded) - icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout) - counter = sent.timeExceeded - default: - panic(fmt.Sprintf("unsupported ICMP type %T", reason)) - } + icmpHdr.SetType(icmpType) + icmpHdr.SetCode(icmpCode) + icmpHdr.SetTypeSpecific(typeSpecific) + dataRange := newPkt.Data().AsRange() icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpHdr, diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 7c2a3e56b..03d9f425c 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -225,8 +226,8 @@ func TestICMPCounts(t *testing.T) { t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") } addr := lladdr0.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } @@ -407,8 +408,12 @@ func newTestContext(t *testing.T) *testContext { if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil { t.Fatalf("CreateNIC s0: %v", err) } - if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress lladdr0: %v", err) + llProtocolAddr0 := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := c.s0.AddProtocolAddress(nicID, llProtocolAddr0, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr0, err) } c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1) @@ -416,8 +421,12 @@ func newTestContext(t *testing.T) *testContext { if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil { t.Fatalf("CreateNIC failed: %v", err) } - if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil { - t.Fatalf("AddAddress lladdr1: %v", err) + llProtocolAddr1 := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr1.WithPrefix(), + } + if err := c.s1.AddProtocolAddress(nicID, llProtocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr1, err) } subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -690,8 +699,12 @@ func TestICMPChecksumValidationSimple(t *testing.T) { t.Fatalf("CreateNIC(_, _) = %s", err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -883,8 +896,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { t.Fatalf("CreateNIC(_, _) = %s", err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -1065,8 +1082,12 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -1240,8 +1261,12 @@ func TestLinkAddressRequest(t *testing.T) { } if len(test.nicAddr) != 0 { - if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: test.nicAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } @@ -1411,12 +1436,14 @@ func TestPacketQueing(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, Clock: clock, }) + // Make sure ICMP rate limiting doesn't get in our way. + s.SetICMPLimit(rate.Inf) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err) + if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err) } s.SetRouteTable([]tcpip.Route{ @@ -1669,8 +1696,12 @@ func TestCallsToNeighborCache(t *testing.T) { if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } { @@ -1704,8 +1735,8 @@ func TestCallsToNeighborCache(t *testing.T) { t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") } addr := lladdr0.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index d4bd61748..7d3e1fd53 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -748,7 +748,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -761,7 +761,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // We should do this for every packet, rather than only NATted packets, but // removing this check short circuits broadcasts before they are sent out to // other hosts. - if pkt.NatDone { + if pkt.DNATDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { // Since we rewrote the packet but it is being routed back to us, we @@ -788,7 +788,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol // Postrouting NAT can only change the source address, and does not alter the // route or outgoing interface of the packet. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesPostroutingDropped.Increment() return nil @@ -871,7 +871,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName) + outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName) stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped))) for pkt := range outputDropped { pkts.Remove(pkt) @@ -897,7 +897,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // We ignore the list of NAT-ed packets here because Postrouting NAT can only // change the source address, and does not alter the route or outgoing // interface of the packet. - postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName) + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, e, outNicName) stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) for pkt := range postroutingDropped { pkts.Remove(pkt) @@ -984,7 +984,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(ep.nic.ID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -1015,7 +1015,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(r.NICID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -1024,7 +1024,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { // We need to do a deep copy of the IP packet because // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do // not own it. - newHdr := header.IPv6(stack.PayloadSince(pkt.NetworkHeader())) + newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength())) + newHdr := header.IPv6(newPkt.NetworkHeader().View()) // As per RFC 8200 section 3, // @@ -1032,11 +1033,13 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { // each node that forwards the packet. newHdr.SetHopLimit(hopLimit - 1) - switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: buffer.View(newHdr).ToVectorisedView(), - IsForwardedPacket: true, - })); err.(type) { + forwardToEp, ok := e.protocol.getEndpointForNIC(r.NICID()) + if !ok { + // The interface was removed after we obtained the route. + return &ip.ErrOther{Err: &tcpip.ErrUnknownDevice{}} + } + + switch err := forwardToEp.writePacket(r, newPkt, newPkt.TransportProtocolNumber, true /* headerIncluded */); err.(type) { case nil: return nil case *tcpip.ErrMessageTooLong: @@ -1097,7 +1100,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -1180,7 +1183,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok { // iptables is telling us to drop the packet. stats.IPTablesInputDropped.Increment() return @@ -1534,27 +1537,36 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe // If the last header in the payload isn't a known IPv6 extension header, // handle it as if it is transport layer data. - // Calculate the number of octets parsed from data. We want to remove all - // the data except the unparsed portion located at the end, which its size - // is extHdr.Buf.Size(). + // Calculate the number of octets parsed from data. We want to consume all + // the data except the unparsed portion located at the end, whose size is + // extHdr.Buf.Size(). trim := pkt.Data().Size() - extHdr.Buf.Size() // For unfragmented packets, extHdr still contains the transport header. - // Get rid of it. + // Consume that too. // // For reassembled fragments, pkt.TransportHeader is unset, so this is a // no-op and pkt.Data begins with the transport header. trim += pkt.TransportHeader().View().Size() - pkt.Data().DeleteFront(trim) + if _, ok := pkt.Data().Consume(trim); !ok { + stats.MalformedPacketsReceived.Increment() + return fmt.Errorf("could not consume %d bytes", trim) + } + + proto := tcpip.TransportProtocolNumber(extHdr.Identifier) + // If the packet was reassembled from a fragment, it will not have a + // transport header set yet. + if pkt.TransportHeader().View().IsEmpty() { + e.protocol.parseTransport(pkt, proto) + } stats.PacketsDelivered.Increment() - if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { - pkt.TransportProtocolNumber = p + if proto == header.ICMPv6ProtocolNumber { e.handleICMP(pkt, hasFragmentHeader, routerAlert) } else { stats.PacketsDelivered.Increment() - switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res { + switch res := e.dispatcher.DeliverTransportPacket(proto, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: // As per RFC 4443 section 3.1: @@ -1628,12 +1640,12 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { } // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. -func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) { // TODO(b/169350103): add checks here after making sure we no longer receive // an empty address. e.mu.Lock() defer e.mu.Unlock() - return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated) + return e.addAndAcquirePermanentAddressLocked(addr, properties) } // addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but @@ -1643,8 +1655,8 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p // solicited-node multicast group and start duplicate address detection. // // Precondition: e.mu must be write locked. -func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { - addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) +func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) { + addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties) if err != nil { return nil, err } @@ -1987,6 +1999,9 @@ type protocol struct { // eps is keyed by NICID to allow protocol methods to retrieve an endpoint // when handling a packet, by looking at which NIC handled the packet. eps map[tcpip.NICID]*endpoint + + // ICMP types for which the stack's global rate limiting must apply. + icmpRateLimitedTypes map[header.ICMPv6Type]struct{} } ids []uint32 @@ -1998,7 +2013,8 @@ type protocol struct { // Must be accessed using atomic operations. defaultTTL uint32 - fragmentation *fragmentation.Fragmentation + fragmentation *fragmentation.Fragmentation + icmpRateLimiter *stack.ICMPRateLimiter } // Number returns the ipv6 protocol number. @@ -2011,11 +2027,6 @@ func (p *protocol) MinimumPacketSize() int { return header.IPv6MinimumSize } -// DefaultPrefixLen returns the IPv6 default prefix length. -func (p *protocol) DefaultPrefixLen() int { - return header.IPv6AddressSize * 8 -} - // ParseAddresses implements stack.NetworkProtocol. func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv6(v) @@ -2087,6 +2098,13 @@ func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint { return nil } +func (p *protocol) getEndpointForNIC(id tcpip.NICID) (*endpoint, bool) { + p.mu.RLock() + defer p.mu.RUnlock() + ep, ok := p.mu.eps[id] + return ep, ok +} + func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { p.mu.Lock() defer p.mu.Unlock() @@ -2149,19 +2167,29 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv6, bool) } if hasTransportHdr { - switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { - case stack.ParsedOK: - case stack.UnknownTransportProtocol, stack.TransportLayerParseError: - // The transport layer will handle unknown protocols and transport layer - // parsing errors. - default: - panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) - } + p.parseTransport(pkt, transProtoNum) } return h, true } +func (p *protocol) parseTransport(pkt *stack.PacketBuffer, transProtoNum tcpip.TransportProtocolNumber) { + if transProtoNum == header.ICMPv6ProtocolNumber { + // The transport layer will handle transport layer parsing errors. + _ = parse.ICMPv6(pkt) + return + } + + switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { + case stack.ParsedOK: + case stack.UnknownTransportProtocol, stack.TransportLayerParseError: + // The transport layer will handle unknown protocols and transport layer + // parsing errors. + default: + panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) + } +} + // Parse implements stack.NetworkProtocol. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt) @@ -2172,6 +2200,18 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return proto, !fragMore && fragOffset == 0, true } +// allowICMPReply reports whether an ICMP reply with provided type may +// be sent following the rate mask options and global ICMP rate limiter. +func (p *protocol) allowICMPReply(icmpType header.ICMPv6Type) bool { + p.mu.RLock() + defer p.mu.RUnlock() + + if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok { + return p.stack.AllowICMPMessage() + } + return true +} + // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload MTU and the length of every IPv6 header. // Note that this is different than the Payload Length field of the IPv6 header, @@ -2268,6 +2308,21 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) p.mu.eps = make(map[tcpip.NICID]*endpoint) p.SetDefaultTTL(DefaultTTL) + // Set default ICMP rate limiting to Linux defaults. + // + // Default: 0-1,3-127 (rate limit ICMPv6 errors except Packet Too Big) + // See https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt. + defaultIcmpTypes := make(map[header.ICMPv6Type]struct{}) + for i := header.ICMPv6Type(0); i < header.ICMPv6EchoRequest; i++ { + switch i { + case header.ICMPv6PacketTooBig: + // Do not rate limit packet too big by default. + default: + defaultIcmpTypes[i] = struct{}{} + } + } + p.mu.icmpRateLimitedTypes = defaultIcmpTypes + return p } } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index d2a23fd4f..e5286081e 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -41,12 +41,12 @@ import ( ) const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") // The least significant 3 bytes are the same as addr2 so both addr2 and // addr3 will have the same solicited-node address. - addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" - addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03" + addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02") + addr4 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03") // Tests use the extension header identifier values as uint8 instead of // header.IPv6ExtensionHeaderIdentifier. @@ -298,16 +298,24 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { // addr2/addr3 yet as we haven't added those addresses. test.rxf(t, s, e, addr1, snmc, 0) - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr2, err) } // Should receive a packet destined to the solicited node address of // addr2/addr3 now that we have added added addr2. test.rxf(t, s, e, addr1, snmc, 1) - if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr3.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr3, err) } // Should still receive a packet destined to the solicited node address of @@ -374,8 +382,12 @@ func TestAddIpv6Address(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, test.addr); err != nil { - t.Fatalf("AddAddress(%d, %d, nil) = %s", nicID, ProtocolNumber, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: test.addr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if addr, err := s.GetMainNICAddress(nicID, ProtocolNumber); err != nil { @@ -898,8 +910,12 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // Add a default route so that a return packet knows where to go. @@ -1992,8 +2008,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } wq := waiter.Queue{} @@ -2060,8 +2080,8 @@ func TestReceiveIPv6Fragments(t *testing.T) { func TestInvalidIPv6Fragments(t *testing.T) { const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") nicID = 1 hoplimit = 255 @@ -2150,8 +2170,12 @@ func TestInvalidIPv6Fragments(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv6EmptySubnet, @@ -2216,8 +2240,8 @@ func TestInvalidIPv6Fragments(t *testing.T) { func TestFragmentReassemblyTimeout(t *testing.T) { const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") nicID = 1 hoplimit = 255 @@ -2402,8 +2426,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv6EmptySubnet, @@ -2645,11 +2673,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { t.Fatalf("CreateNIC(1, _) failed: %s", err) } const ( - src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + src = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + dst = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") ) - if err := s.AddAddress(1, ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: src.WithPrefix(), + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff") @@ -3297,8 +3329,8 @@ func TestForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) } incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr} - if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv6ProtoAddr, err) + if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv6ProtoAddr, err) } outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "") @@ -3306,8 +3338,8 @@ func TestForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) } outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr} - if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv6ProtoAddr, err) + if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv6ProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -3341,7 +3373,8 @@ func TestForwarding(t *testing.T) { ipHeaderLength := header.IPv6MinimumSize icmpHeaderLength := header.ICMPv6MinimumSize - totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen + payloadLength := icmpHeaderLength + test.payloadLength + extHdrLen + totalLength := ipHeaderLength + payloadLength hdr := buffer.NewPrependable(totalLength) hdr.Prepend(test.payloadLength) icmpH := header.ICMPv6(hdr.Prepend(icmpHeaderLength)) @@ -3359,7 +3392,7 @@ func TestForwarding(t *testing.T) { copy(hdr.Prepend(extHdrLen), extHdrBytes) ip := header.IPv6(hdr.Prepend(ipHeaderLength)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.ICMPv6MinimumSize + test.payloadLength), + PayloadLength: uint16(payloadLength), TransportProtocol: transportProtocol, HopLimit: test.TTL, SrcAddr: test.sourceAddr, @@ -3489,3 +3522,149 @@ func TestMultiCounterStatsInitialization(t *testing.T) { t.Error(err) } } + +func TestIcmpRateLimit(t *testing.T) { + var ( + host1IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::1").To16()), + PrefixLen: 64, + }, + } + host2IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::2").To16()), + PrefixLen: 64, + }, + } + ) + const icmpBurst = 5 + e := channel.New(1, defaultMTU, tcpip.LinkAddress("")) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: faketime.NewManualClock(), + }) + s.SetICMPBurst(icmpBurst) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err) + } + s.SetRouteTable([]tcpip.Route{ + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: nicID, + }, + }) + tests := []struct { + name string + createPacket func() buffer.View + check func(*testing.T, *channel.Endpoint, int) + }{ + { + name: "echo", + createPacket: func() buffer.View { + totalLength := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLength) + icmpH := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + icmpH.SetIdent(1) + icmpH.SetSequence(1) + icmpH.SetType(header.ICMPv6EchoRequest) + icmpH.SetCode(header.ICMPv6UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpH, + Src: host2IPv6Addr.AddressWithPrefix.Address, + Dst: host1IPv6Addr.AddressWithPrefix.Address, + })) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 1, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + return hdr.View() + }, + check: func(t *testing.T, e *channel.Endpoint, round int) { + p, ok := e.Read() + if !ok { + t.Fatalf("expected echo response, no packet read in endpoint in round %d", round) + } + if got, want := p.Proto, header.IPv6ProtocolNumber; got != want { + t.Errorf("got p.Proto = %d, want = %d", got, want) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply), + )) + }, + }, + { + name: "dst unreachable", + createPacket: func() buffer.View { + totalLength := header.IPv6MinimumSize + header.UDPMinimumSize + hdr := buffer.NewPrependable(totalLength) + udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + udpH.Encode(&header.UDPFields{ + SrcPort: 100, + DstPort: 101, + Length: header.UDPMinimumSize, + }) + + // Calculate the UDP checksum and set it. + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize) + sum = header.Checksum(nil, sum) + udpH.SetChecksum(^udpH.CalculateChecksum(sum)) + + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: 1, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + return hdr.View() + }, + check: func(t *testing.T, e *channel.Endpoint, round int) { + p, ok := e.Read() + if round >= icmpBurst { + if ok { + t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round) + } + return + } + if !ok { + t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + )) + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + for round := 0; round < icmpBurst+1; round++ { + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: testCase.createPacket().ToVectorisedView(), + })) + testCase.check(t, e, round) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index bc9cf6999..3e5c438d3 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -75,8 +75,12 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { // The stack will join an address's solicited node multicast address when // an address is added. An MLD report message should be sent for the // solicited-node group. - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: linkLocalAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -216,8 +220,13 @@ func TestSendQueuedMLDReports(t *testing.T) { // Note, we will still expect to send a report for the global address's // solicited node address from the unspecified address as per RFC 3590 // section 4. - if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err) + properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} + globalProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: globalAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, globalProtocolAddr, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, globalProtocolAddr, properties, err) } reportCounter++ if got := reportStat.Value(); got != reportCounter { @@ -252,8 +261,12 @@ func TestSendQueuedMLDReports(t *testing.T) { // Adding a link-local address should send a report for its solicited node // address and globalMulticastAddr. - if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err) + linkLocalProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: linkLocalAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, linkLocalProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, linkLocalProtocolAddr, err) } if dadResolutionTime != 0 { reportCounter++ @@ -567,8 +580,12 @@ func TestMLDSkipProtocol(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: linkLocalAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 8837d66d8..938427420 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -1130,7 +1130,11 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config return nil } - addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated) + addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.AddressProperties{ + PEB: stack.FirstPrimaryEndpoint, + ConfigType: configType, + Deprecated: deprecated, + }) if err != nil { panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err)) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index f0186c64e..8297a7e10 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -144,8 +144,12 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf) @@ -406,8 +410,12 @@ func TestNeighborSolicitationResponse(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: nicAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -602,8 +610,12 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf) @@ -831,8 +843,12 @@ func TestNDPValidation(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) @@ -962,8 +978,12 @@ func TestNeighborAdvertisementValidation(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } ndpNASize := header.ICMPv6NeighborAdvertMinimumSize @@ -1283,8 +1303,12 @@ func TestCheckDuplicateAddress(t *testing.T) { checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[dadPacketsSent])}), )) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } checkDADMsg() diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go index 1b96b1fb8..26640b7ee 100644 --- a/pkg/tcpip/network/multicast_group_test.go +++ b/pkg/tcpip/network/multicast_group_test.go @@ -151,15 +151,22 @@ func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.Link if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - addr := tcpip.AddressWithPrefix{ - Address: stackIPv4Addr, - PrefixLen: defaultIPv4PrefixLength, + addr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: stackIPv4Addr, + PrefixLen: defaultIPv4PrefixLength, + }, + } + if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err) } - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: linkLocalIPv6Addr1.WithPrefix(), } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1, err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } return s, clock diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 009cab643..05b879543 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -146,8 +146,12 @@ func main() { log.Fatal(err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, addr); err != nil { - log.Fatal(err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.WithPrefix(), + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Add default route. diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index c10b19aa0..a72afadda 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -124,13 +124,13 @@ func main() { log.Fatalf("Bad IP address: %v", addrName) } - var addr tcpip.Address + var addrWithPrefix tcpip.AddressWithPrefix var proto tcpip.NetworkProtocolNumber if parsedAddr.To4() != nil { - addr = tcpip.Address(parsedAddr.To4()) + addrWithPrefix = tcpip.Address(parsedAddr.To4()).WithPrefix() proto = ipv4.ProtocolNumber } else if parsedAddr.To16() != nil { - addr = tcpip.Address(parsedAddr.To16()) + addrWithPrefix = tcpip.Address(parsedAddr.To16()).WithPrefix() proto = ipv6.ProtocolNumber } else { log.Fatalf("Unknown IP type: %v", addrName) @@ -176,11 +176,15 @@ func main() { log.Fatal(err) } - if err := s.AddAddress(1, proto, addr); err != nil { - log.Fatal(err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: proto, + AddressWithPrefix: addrWithPrefix, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } - subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr)))) + subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addrWithPrefix.Address))), tcpip.AddressMask(strings.Repeat("\x00", len(addrWithPrefix.Address)))) if err != nil { log.Fatal(err) } diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index 34ac62444..b0b2d0afd 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -170,10 +170,14 @@ type SocketOptions struct { // message is passed with incoming packets. receiveTClassEnabled uint32 - // receivePacketInfoEnabled is used to specify if more inforamtion is - // provided with incoming packets such as interface index and address. + // receivePacketInfoEnabled is used to specify if more information is + // provided with incoming IPv4 packets. receivePacketInfoEnabled uint32 + // receivePacketInfoEnabled is used to specify if more information is + // provided with incoming IPv6 packets. + receiveIPv6PacketInfoEnabled uint32 + // hdrIncludeEnabled is used to indicate for a raw endpoint that all packets // being written have an IP header and the endpoint should not attach an IP // header. @@ -360,6 +364,16 @@ func (so *SocketOptions) SetReceivePacketInfo(v bool) { storeAtomicBool(&so.receivePacketInfoEnabled, v) } +// GetIPv6ReceivePacketInfo gets value for IPV6_RECVPKTINFO option. +func (so *SocketOptions) GetIPv6ReceivePacketInfo() bool { + return atomic.LoadUint32(&so.receiveIPv6PacketInfoEnabled) != 0 +} + +// SetIPv6ReceivePacketInfo sets value for IPV6_RECVPKTINFO option. +func (so *SocketOptions) SetIPv6ReceivePacketInfo(v bool) { + storeAtomicBool(&so.receiveIPv6PacketInfoEnabled, v) +} + // GetHeaderIncluded gets value for IP_HDRINCL option. func (so *SocketOptions) GetHeaderIncluded() bool { return atomic.LoadUint32(&so.hdrIncludedEnabled) != 0 diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 6c42ab29b..ead36880f 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -48,7 +48,6 @@ go_library( "hook_string.go", "icmp_rate_limit.go", "iptables.go", - "iptables_state.go", "iptables_targets.go", "iptables_types.go", "neighbor_cache.go", @@ -133,6 +132,7 @@ go_test( name = "stack_test", size = "small", srcs = [ + "conntrack_test.go", "forwarding_test.go", "neighbor_cache_test.go", "neighbor_entry_test.go", diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index ae0bb4ace..7e4b5bf74 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -117,10 +117,10 @@ func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressS } // AddAndAcquirePermanentAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) { +func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error) { a.mu.Lock() defer a.mu.Unlock() - ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */) + ep, err := a.addAndAcquireAddressLocked(addr, properties, true /* permanent */) // From https://golang.org/doc/faq#nil_error: // // Under the covers, interfaces are implemented as two elements, a type T and @@ -149,7 +149,7 @@ func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.Addr func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) { a.mu.Lock() defer a.mu.Unlock() - ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */) + ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: peb}, false /* permanent */) // From https://golang.org/doc/faq#nil_error: // // Under the covers, interfaces are implemented as two elements, a type T and @@ -180,7 +180,7 @@ func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.Addr // returned, regardless the kind of address that is being added. // // Precondition: a.mu must be write locked. -func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, tcpip.Error) { +func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, properties AddressProperties, permanent bool) (*addressState, tcpip.Error) { // attemptAddToPrimary is false when the address is already in the primary // address list. attemptAddToPrimary := true @@ -208,7 +208,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address // We now promote the address. for i, s := range a.mu.primary { if s == addrState { - switch peb { + switch properties.PEB { case CanBePrimaryEndpoint: // The address is already in the primary address list. attemptAddToPrimary = false @@ -222,7 +222,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address case NeverPrimaryEndpoint: a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...) default: - panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb)) + panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB)) } break } @@ -262,11 +262,11 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address } // Acquire the address before returning it. addrState.mu.refs++ - addrState.mu.deprecated = deprecated - addrState.mu.configType = configType + addrState.mu.deprecated = properties.Deprecated + addrState.mu.configType = properties.ConfigType if attemptAddToPrimary { - switch peb { + switch properties.PEB { case NeverPrimaryEndpoint: case CanBePrimaryEndpoint: a.mu.primary = append(a.mu.primary, addrState) @@ -285,7 +285,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address a.mu.primary[0] = addrState } default: - panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb)) + panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB)) } } @@ -489,12 +489,12 @@ func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tc // Proceed to add a new temporary endpoint. addr := localAddr.WithPrefix() - ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */) + ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: tempPEB}, false /* permanent */) if err != nil { // addAndAcquireAddressLocked only returns an error if the address is // already assigned but we just checked above if the address exists so we // expect no error. - panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err)) + panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, AddressProperties{PEB: %s}, false): %s", addr, tempPEB, err)) } // From https://golang.org/doc/faq#nil_error: diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go index 140f146f6..c55f85743 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state_test.go +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -38,9 +38,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) { } { - ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) + ep, err := s.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint}) if err != nil { - t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err) + t.Fatalf("s.AddAndAcquirePermanentAddress(%s, AddressProperties{PEB: NeverPrimaryEndpoint}): %s", addr, err) } // We don't need the address endpoint. ep.DecRef() diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 068dab7ce..a3f403855 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -37,23 +37,9 @@ import ( // Our hash table has 16K buckets. const numBuckets = 1 << 14 -// Direction of the tuple. -type direction int - -const ( - dirOriginal direction = iota - dirReply -) - -// Manipulation type for the connection. -// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and -// DNAT at the same time. -type manipType int - const ( - manipNone manipType = iota - manipSource - manipDestination + establishedTimeout time.Duration = 5 * 24 * time.Hour + unestablishedTimeout time.Duration = 120 * time.Second ) // tuple holds a connection's identifying and manipulating data in one @@ -64,13 +50,22 @@ type tuple struct { // tupleEntry is used to build an intrusive list of tuples. tupleEntry - tupleID - // conn is the connection tracking entry this tuple belongs to. conn *conn - // direction is the direction of the tuple. - direction direction + // reply is true iff the tuple's direction is opposite that of the first + // packet seen on the connection. + reply bool + + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu + tupleID tupleID +} + +func (t *tuple) id() tupleID { + t.mu.RLock() + defer t.mu.RUnlock() + return t.tupleID } // tupleID uniquely identifies a connection in one direction. It currently @@ -103,50 +98,43 @@ func (ti tupleID) reply() tupleID { // // +stateify savable type conn struct { + ct *ConnTrack + // original is the tuple in original direction. It is immutable. original tuple - // reply is the tuple in reply direction. It is immutable. + // reply is the tuple in reply direction. reply tuple - // manip indicates if the packet should be manipulated. It is immutable. - // TODO(gvisor.dev/issue/5696): Support updating manipulation type. - manip manipType - - // tcbHook indicates if the packet is inbound or outbound to - // update the state of tcb. It is immutable. - tcbHook Hook - - // mu protects all mutable state. - mu sync.Mutex `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + // Indicates that the connection has been finalized and may handle replies. + // + // +checklocks:mu + finalized bool + // sourceManip indicates the packet's source is manipulated. + // + // +checklocks:mu + sourceManip bool + // destinationManip indicates the packet's destination is manipulated. + // + // +checklocks:mu + destinationManip bool // tcb is TCB control block. It is used to keep track of states - // of tcp connection and is protected by mu. + // of tcp connection. + // + // +checklocks:mu tcb tcpconntrack.TCB // lastUsed is the last time the connection saw a relevant packet, and - // is updated by each packet on the connection. It is protected by mu. + // is updated by each packet on the connection. // - // TODO(gvisor.dev/issue/5939): do not use the ambient clock. - lastUsed time.Time `state:".(unixTime)"` -} - -// newConn creates new connection. -func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { - conn := conn{ - manip: manip, - tcbHook: hook, - lastUsed: time.Now(), - } - conn.original = tuple{conn: &conn, tupleID: orig} - conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} - return &conn + // +checklocks:mu + lastUsed tcpip.MonotonicTime } // timedOut returns whether the connection timed out based on its state. -func (cn *conn) timedOut(now time.Time) bool { - const establishedTimeout = 5 * 24 * time.Hour - const defaultTimeout = 120 * time.Second - cn.mu.Lock() - defer cn.mu.Unlock() +func (cn *conn) timedOut(now tcpip.MonotonicTime) bool { + cn.mu.RLock() + defer cn.mu.RUnlock() if cn.tcb.State() == tcpconntrack.ResultAlive { // Use the same default as Linux, which doesn't delete // established connections for 5(!) days. @@ -154,22 +142,31 @@ func (cn *conn) timedOut(now time.Time) bool { } // Use the same default as Linux, which lets connections in most states // other than established remain for <= 120 seconds. - return now.Sub(cn.lastUsed) > defaultTimeout + return now.Sub(cn.lastUsed) > unestablishedTimeout } // update the connection tracking state. // -// Precondition: cn.mu must be held. -func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) { +// +checklocks:cn.mu +func (cn *conn) updateLocked(pkt *PacketBuffer, reply bool) { + if pkt.TransportProtocolNumber != header.TCPProtocolNumber { + return + } + + tcpHeader := header.TCP(pkt.TransportHeader().View()) + // Update the state of tcb. tcb assumes it's always initialized on the // client. However, we only need to know whether the connection is // established or not, so the client/server distinction isn't important. if cn.tcb.IsEmpty() { cn.tcb.Init(tcpHeader) - } else if hook == cn.tcbHook { - cn.tcb.UpdateStateOutbound(tcpHeader) - } else { + return + } + + if reply { cn.tcb.UpdateStateInbound(tcpHeader) + } else { + cn.tcb.UpdateStateOutbound(tcpHeader) } } @@ -194,44 +191,37 @@ type ConnTrack struct { // It is immutable. seed uint32 + // clock provides timing used to determine conntrack reapings. + clock tcpip.Clock + + mu sync.RWMutex `state:"nosave"` // mu protects the buckets slice, but not buckets' contents. Only take // the write lock if you are modifying the slice or saving for S/R. - mu sync.RWMutex `state:"nosave"` - - // buckets is protected by mu. + // + // +checklocks:mu buckets []bucket } // +stateify savable type bucket struct { - // mu protects tuples. - mu sync.Mutex `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu tuples tupleList } -// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid -// TCP header. -// -// Preconditions: pkt.NetworkHeader() is valid. -func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { - netHeader := pkt.Network() - if netHeader.TransportProtocol() != header.TCPProtocolNumber { - return tupleID{}, &tcpip.ErrUnknownProtocol{} - } - - tcpHeader := header.TCP(pkt.TransportHeader().View()) - if len(tcpHeader) < header.TCPMinimumSize { - return tupleID{}, &tcpip.ErrUnknownProtocol{} +func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) { + switch pkt.TransportProtocolNumber { + case header.TCPProtocolNumber: + if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize { + return tcpHeader, true + } + case header.UDPProtocolNumber: + if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize { + return udpHeader, true + } } - return tupleID{ - srcAddr: netHeader.SourceAddress(), - srcPort: tcpHeader.SourcePort(), - dstAddr: netHeader.DestinationAddress(), - dstPort: tcpHeader.DestinationPort(), - transProto: netHeader.TransportProtocol(), - netProto: pkt.NetworkProtocolNumber, - }, nil + return nil, false } func (ct *ConnTrack) init() { @@ -240,278 +230,285 @@ func (ct *ConnTrack) init() { ct.buckets = make([]bucket, numBuckets) } -// connFor gets the conn for pkt if it exists, or returns nil -// if it does not. It returns an error when pkt does not contain a valid TCP -// header. -// TODO(gvisor.dev/issue/6168): Support UDP. -func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { - tid, err := packetToTupleID(pkt) - if err != nil { - return nil, dirOriginal +func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple { + netHeader := pkt.Network() + transportHeader, ok := getTransportHeader(pkt) + if !ok { + return nil + } + + tid := tupleID{ + srcAddr: netHeader.SourceAddress(), + srcPort: transportHeader.SourcePort(), + dstAddr: netHeader.DestinationAddress(), + dstPort: transportHeader.DestinationPort(), + transProto: pkt.TransportProtocolNumber, + netProto: pkt.NetworkProtocolNumber, } - return ct.connForTID(tid) -} -func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { - bucket := ct.bucket(tid) - now := time.Now() + bktID := ct.bucket(tid) ct.mu.RLock() - defer ct.mu.RUnlock() - ct.buckets[bucket].mu.Lock() - defer ct.buckets[bucket].mu.Unlock() - - // Iterate over the tuples in a bucket, cleaning up any unused - // connections we find. - for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() { - // Clean up any timed-out connections we happen to find. - if ct.reapTupleLocked(other, bucket, now) { - // The tuple expired. - continue - } - if tid == other.tupleID { - return other.conn, other.direction - } + bkt := &ct.buckets[bktID] + ct.mu.RUnlock() + + now := ct.clock.NowMonotonic() + if t := bkt.connForTID(tid, now); t != nil { + return t } - return nil, dirOriginal -} + bkt.mu.Lock() + defer bkt.mu.Unlock() -func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { - tid, err := packetToTupleID(pkt) - if err != nil { - return nil + // Make sure a connection wasn't added between when we last checked the + // bucket and acquired the bucket's write lock. + if t := bkt.connForTIDRLocked(tid, now); t != nil { + return t } - if hook != Prerouting && hook != Output { - return nil + + // This is the first packet we're seeing for the connection. Create an entry + // for this new connection. + conn := &conn{ + ct: ct, + original: tuple{tupleID: tid}, + reply: tuple{tupleID: tid.reply(), reply: true}, + lastUsed: now, } + conn.original.conn = conn + conn.reply.conn = conn - replyTID := tid.reply() - replyTID.srcAddr = address - replyTID.srcPort = port + // For now, we only map an entry for the packet's original tuple as NAT may be + // performed on this connection. Until the packet goes through all the hooks + // and its final address/port is known, we cannot know what the response + // packet's addresses/ports will look like. + // + // This is okay because the destination cannot send its response until it + // receives the packet; the packet will only be received once all the hooks + // have been performed. + // + // See (*conn).finalize. + bkt.tuples.PushFront(&conn.original) + return &conn.original +} - conn, _ := ct.connForTID(tid) - if conn != nil { - // The connection is already tracked. - // TODO(gvisor.dev/issue/5696): Support updating an existing connection. - return nil - } - conn = newConn(tid, replyTID, manipDestination, hook) - ct.insertConn(conn) - return conn +func (ct *ConnTrack) connForTID(tid tupleID) *tuple { + bktID := ct.bucket(tid) + + ct.mu.RLock() + bkt := &ct.buckets[bktID] + ct.mu.RUnlock() + + return bkt.connForTID(tid, ct.clock.NowMonotonic()) } -func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { - tid, err := packetToTupleID(pkt) - if err != nil { - return nil - } - if hook != Input && hook != Postrouting { - return nil +func (bkt *bucket) connForTID(tid tupleID, now tcpip.MonotonicTime) *tuple { + bkt.mu.RLock() + defer bkt.mu.RUnlock() + return bkt.connForTIDRLocked(tid, now) +} + +// +checklocksread:bkt.mu +func (bkt *bucket) connForTIDRLocked(tid tupleID, now tcpip.MonotonicTime) *tuple { + for other := bkt.tuples.Front(); other != nil; other = other.Next() { + if tid == other.id() && !other.conn.timedOut(now) { + return other + } } + return nil +} - replyTID := tid.reply() - replyTID.dstAddr = address - replyTID.dstPort = port +func (ct *ConnTrack) finalize(cn *conn) { + tid := cn.reply.id() + id := ct.bucket(tid) - conn, _ := ct.connForTID(tid) - if conn != nil { - // The connection is already tracked. - // TODO(gvisor.dev/issue/5696): Support updating an existing connection. - return nil + ct.mu.RLock() + bkt := &ct.buckets[id] + ct.mu.RUnlock() + + bkt.mu.Lock() + defer bkt.mu.Unlock() + + if t := bkt.connForTIDRLocked(tid, ct.clock.NowMonotonic()); t != nil { + // Another connection for the reply already exists. We can't do much about + // this so we leave the connection cn represents in a state where it can + // send packets but its responses will be mapped to some other connection. + // This may be okay if the connection only expects to send packets without + // any responses. + return } - conn = newConn(tid, replyTID, manipSource, hook) - ct.insertConn(conn) - return conn + + bkt.tuples.PushFront(&cn.reply) } -// insertConn inserts conn into the appropriate table bucket. -func (ct *ConnTrack) insertConn(conn *conn) { - // Lock the buckets in the correct order. - tupleBucket := ct.bucket(conn.original.tupleID) - replyBucket := ct.bucket(conn.reply.tupleID) - ct.mu.RLock() - defer ct.mu.RUnlock() - if tupleBucket < replyBucket { - ct.buckets[tupleBucket].mu.Lock() - ct.buckets[replyBucket].mu.Lock() - } else if tupleBucket > replyBucket { - ct.buckets[replyBucket].mu.Lock() - ct.buckets[tupleBucket].mu.Lock() - } else { - // Both tuples are in the same bucket. - ct.buckets[tupleBucket].mu.Lock() - } - - // Now that we hold the locks, ensure the tuple hasn't been inserted by - // another thread. - // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too? - alreadyInserted := false - for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { - if other.tupleID == conn.original.tupleID { - alreadyInserted = true - break +func (cn *conn) finalize() { + { + cn.mu.RLock() + finalized := cn.finalized + cn.mu.RUnlock() + if finalized { + return } } - if !alreadyInserted { - // Add the tuple to the map. - ct.buckets[tupleBucket].tuples.PushFront(&conn.original) - ct.buckets[replyBucket].tuples.PushFront(&conn.reply) + cn.mu.Lock() + finalized := cn.finalized + cn.finalized = true + cn.mu.Unlock() + if finalized { + return } - // Unlocking can happen in any order. - ct.buckets[tupleBucket].mu.Unlock() - if tupleBucket != replyBucket { - ct.buckets[replyBucket].mu.Unlock() // +checklocksforce - } + cn.ct.finalize(cn) } -// handlePacket will manipulate the port and address of the packet if the -// connection exists. Returns whether, after the packet traverses the tables, -// it should create a new entry in the table. -func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { - if pkt.NatDone { - return false +// performNAT setups up the connection for the specified NAT. +// +// Generally, only the first packet of a connection reaches this method; other +// other packets will be manipulated without needing to modify the connection. +func (cn *conn) performNAT(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) { + cn.performNATIfNoop(port, address, dnat) + cn.handlePacket(pkt, hook, r) +} + +func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool) { + cn.mu.Lock() + defer cn.mu.Unlock() + + if cn.finalized { + return } - switch hook { - case Prerouting, Input, Output, Postrouting: - default: - return false + if dnat { + if cn.destinationManip { + return + } + cn.destinationManip = true + } else { + if cn.sourceManip { + return + } + cn.sourceManip = true } - // TODO(gvisor.dev/issue/6168): Support UDP. - if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { + cn.reply.mu.Lock() + defer cn.reply.mu.Unlock() + + if dnat { + cn.reply.tupleID.srcAddr = address + cn.reply.tupleID.srcPort = port + } else { + cn.reply.tupleID.dstAddr = address + cn.reply.tupleID.dstPort = port + } +} + +// handlePacket attempts to handle a packet and perform NAT if the connection +// has had NAT performed on it. +// +// Returns true if the packet can skip the NAT table. +func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool { + transportHeader, ok := getTransportHeader(pkt) + if !ok { return false } - conn, dir := ct.connFor(pkt) - // Connection not found for the packet. - if conn == nil { - // If this is the last hook in the data path for this packet (Input if - // incoming, Postrouting if outgoing), indicate that a connection should be - // inserted by the end of this hook. - return hook == Input || hook == Postrouting + fullChecksum := false + updatePseudoHeader := false + natDone := &pkt.SNATDone + dnat := false + switch hook { + case Prerouting: + // Packet came from outside the stack so it must have a checksum set + // already. + fullChecksum = true + updatePseudoHeader = true + + natDone = &pkt.DNATDone + dnat = true + case Input: + case Forward: + panic("should not handle packet in the forwarding hook") + case Output: + natDone = &pkt.DNATDone + dnat = true + fallthrough + case Postrouting: + if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { + updatePseudoHeader = true + } else if rt.RequiresTXTransportChecksum() { + fullChecksum = true + updatePseudoHeader = true + } + default: + panic(fmt.Sprintf("unrecognized hook = %d", hook)) } - netHeader := pkt.Network() - tcpHeader := header.TCP(pkt.TransportHeader().View()) - if len(tcpHeader) < header.TCPMinimumSize { - return false + if *natDone { + panic(fmt.Sprintf("packet already had NAT(dnat=%t) performed at hook=%s; pkt=%#v", dnat, hook, pkt)) } // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be // validated if checksum offloading is off. It may require IP defrag if the // packets are fragmented. - var newAddr tcpip.Address - var newPort uint16 - - updateSRCFields := false - - switch hook { - case Prerouting, Output: - if conn.manip == manipDestination { - switch dir { - case dirOriginal: - newPort = conn.reply.srcPort - newAddr = conn.reply.srcAddr - case dirReply: - newPort = conn.original.dstPort - newAddr = conn.original.dstAddr - - updateSRCFields = true + reply := pkt.tuple.reply + tid, performManip := func() (tupleID, bool) { + cn.mu.Lock() + defer cn.mu.Unlock() + + // Mark the connection as having been used recently so it isn't reaped. + cn.lastUsed = cn.ct.clock.NowMonotonic() + // Update connection state. + cn.updateLocked(pkt, reply) + + var tuple *tuple + if reply { + if dnat { + if !cn.sourceManip { + return tupleID{}, false + } + } else if !cn.destinationManip { + return tupleID{}, false } - pkt.NatDone = true - } - case Input, Postrouting: - if conn.manip == manipSource { - switch dir { - case dirOriginal: - newPort = conn.reply.dstPort - newAddr = conn.reply.dstAddr - - updateSRCFields = true - case dirReply: - newPort = conn.original.srcPort - newAddr = conn.original.srcAddr + + tuple = &cn.original + } else { + if dnat { + if !cn.destinationManip { + return tupleID{}, false + } + } else if !cn.sourceManip { + return tupleID{}, false } - pkt.NatDone = true + + tuple = &cn.reply } - default: - panic(fmt.Sprintf("unrecognized hook = %s", hook)) - } - if !pkt.NatDone { + + return tuple.id(), true + }() + if !performManip { return false } - fullChecksum := false - updatePseudoHeader := false - switch hook { - case Prerouting, Input: - case Output, Postrouting: - // Calculate the TCP checksum and set it. - if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { - updatePseudoHeader = true - } else if r.RequiresTXTransportChecksum() { - fullChecksum = true - updatePseudoHeader = true - } - default: - panic(fmt.Sprintf("unrecognized hook = %s", hook)) + newPort := tid.dstPort + newAddr := tid.dstAddr + if dnat { + newPort = tid.srcPort + newAddr = tid.srcAddr } rewritePacket( - netHeader, - tcpHeader, - updateSRCFields, + pkt.Network(), + transportHeader, + !dnat, fullChecksum, updatePseudoHeader, newPort, newAddr, ) - // Update the state of tcb. - conn.mu.Lock() - defer conn.mu.Unlock() - - // Mark the connection as having been used recently so it isn't reaped. - conn.lastUsed = time.Now() - // Update connection state. - conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) - - return false -} - -// maybeInsertNoop tries to insert a no-op connection entry to keep connections -// from getting clobbered when replies arrive. It only inserts if there isn't -// already a connection for pkt. -// -// This should be called after traversing iptables rules only, to ensure that -// pkt.NatDone is set correctly. -func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { - // If there were a rule applying to this packet, it would be marked - // with NatDone. - if pkt.NatDone { - return - } - - // We only track TCP connections. - if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { - return - } - - // This is the first packet we're seeing for the TCP connection. Insert - // the noop entry (an identity mapping) so that the response doesn't - // get NATed, breaking the connection. - tid, err := packetToTupleID(pkt) - if err != nil { - return - } - conn := newConn(tid, tid.reply(), manipNone, hook) - conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) - ct.insertConn(conn) + *natDone = true + return true } // bucket gets the conntrack bucket for a tupleID. @@ -555,7 +552,7 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim const minInterval = 10 * time.Millisecond const maxInterval = maxFullTraversal / fractionPerReaping - now := time.Now() + now := ct.clock.NowMonotonic() checked := 0 expired := 0 var idx int @@ -563,14 +560,20 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim defer ct.mu.RUnlock() for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ { idx = (i + start) % len(ct.buckets) - ct.buckets[idx].mu.Lock() - for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() { + bkt := &ct.buckets[idx] + bkt.mu.Lock() + for tuple := bkt.tuples.Front(); tuple != nil; { + // reapTupleLocked updates tuple's next pointer so we grab it here. + nextTuple := tuple.Next() + checked++ - if ct.reapTupleLocked(tuple, idx, now) { + if ct.reapTupleLocked(tuple, idx, bkt, now) { expired++ } + + tuple = nextTuple } - ct.buckets[idx].mu.Unlock() + bkt.mu.Unlock() } // We already checked buckets[idx]. idx++ @@ -595,44 +598,51 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim // reapTupleLocked tries to remove tuple and its reply from the table. It // returns whether the tuple's connection has timed out. // -// Preconditions: -// * ct.mu is locked for reading. -// * bucket is locked. -func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool { +// Precondition: ct.mu is read locked and bkt.mu is write locked. +// +checklocksread:ct.mu +// +checklocks:bkt.mu +func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool { if !tuple.conn.timedOut(now) { return false } - // To maintain lock order, we can only reap these tuples if the reply - // appears later in the table. - replyBucket := ct.bucket(tuple.reply()) - if bucket > replyBucket { + // To maintain lock order, we can only reap both tuples if the reply appears + // later in the table. + replyBktID := ct.bucket(tuple.id().reply()) + tuple.conn.mu.RLock() + replyTupleInserted := tuple.conn.finalized + tuple.conn.mu.RUnlock() + if bktID > replyBktID && replyTupleInserted { return true } - // Don't re-lock if both tuples are in the same bucket. - differentBuckets := bucket != replyBucket - if differentBuckets { - ct.buckets[replyBucket].mu.Lock() + // Reap the reply. + if replyTupleInserted { + // Don't re-lock if both tuples are in the same bucket. + if bktID != replyBktID { + replyBkt := &ct.buckets[replyBktID] + replyBkt.mu.Lock() + removeConnFromBucket(replyBkt, tuple) + replyBkt.mu.Unlock() + } else { + removeConnFromBucket(bkt, tuple) + } } - // We have the buckets locked and can remove both tuples. - if tuple.direction == dirOriginal { - ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply) - } else { - ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original) - } - ct.buckets[bucket].tuples.Remove(tuple) + bkt.tuples.Remove(tuple) + return true +} - // Don't re-unlock if both tuples are in the same bucket. - if differentBuckets { - ct.buckets[replyBucket].mu.Unlock() // +checklocksforce +// +checklocks:b.mu +func removeConnFromBucket(b *bucket, tuple *tuple) { + if tuple.reply { + b.tuples.Remove(&tuple.conn.original) + } else { + b.tuples.Remove(&tuple.conn.reply) } - - return true } -func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { +func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { // Lookup the connection. The reply's original destination // describes the original address. tid := tupleID{ @@ -640,17 +650,22 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ srcPort: epID.LocalPort, dstAddr: epID.RemoteAddress, dstPort: epID.RemotePort, - transProto: header.TCPProtocolNumber, + transProto: transProto, netProto: netProto, } - conn, _ := ct.connForTID(tid) - if conn == nil { + t := ct.connForTID(tid) + if t == nil { // Not a tracked connection. return "", 0, &tcpip.ErrNotConnected{} - } else if conn.manip != manipDestination { + } + + t.conn.mu.RLock() + defer t.conn.mu.RUnlock() + if !t.conn.destinationManip { // Unmanipulated destination. return "", 0, &tcpip.ErrInvalidOptionValue{} } - return conn.original.dstAddr, conn.original.dstPort, nil + id := t.conn.original.id() + return id.dstAddr, id.dstPort, nil } diff --git a/pkg/tcpip/stack/conntrack_test.go b/pkg/tcpip/stack/conntrack_test.go new file mode 100644 index 000000000..fb0645ed1 --- /dev/null +++ b/pkg/tcpip/stack/conntrack_test.go @@ -0,0 +1,132 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" +) + +func TestReap(t *testing.T) { + // Initialize conntrack. + clock := faketime.NewManualClock() + ct := ConnTrack{ + clock: clock, + } + ct.init() + ct.checkNumTuples(t, 0) + + // Simulate sending a SYN. This will get the connection into conntrack, but + // the connection won't be considered established. Thus the timeout for + // reaping is unestablishedTimeout. + pkt1 := genTCPPacket() + pkt1.tuple = ct.getConnOrMaybeInsertNoop(pkt1) + // We set rt.routeInfo.Loop to avoid a panic when handlePacket calls + // rt.RequiresTXTransportChecksum. + var rt Route + rt.routeInfo.Loop = PacketLoop + if pkt1.tuple.conn.handlePacket(pkt1, Output, &rt) { + t.Fatal("handlePacket() shouldn't perform any NAT") + } + ct.checkNumTuples(t, 1) + + // Travel a little into the future and send the same SYN. This should update + // lastUsed, but per #6748 didn't. + clock.Advance(unestablishedTimeout / 2) + pkt2 := genTCPPacket() + pkt2.tuple = ct.getConnOrMaybeInsertNoop(pkt2) + if pkt2.tuple.conn.handlePacket(pkt2, Output, &rt) { + t.Fatal("handlePacket() shouldn't perform any NAT") + } + ct.checkNumTuples(t, 1) + + // Travel farther into the future - enough that failing to update lastUsed + // would cause a reaping - and reap the whole table. Make sure the connection + // hasn't been reaped. + clock.Advance(unestablishedTimeout * 3 / 4) + ct.reapEverything() + ct.checkNumTuples(t, 1) + + // Travel past unestablishedTimeout to confirm the tuple is gone. + clock.Advance(unestablishedTimeout / 2) + ct.reapEverything() + ct.checkNumTuples(t, 0) +} + +// genTCPPacket returns an initialized IPv4 TCP packet. +func genTCPPacket() *PacketBuffer { + const packetLen = header.IPv4MinimumSize + header.TCPMinimumSize + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: packetLen, + }) + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + pkt.TransportProtocolNumber = header.TCPProtocolNumber + tcpHdr := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize)) + tcpHdr.Encode(&header.TCPFields{ + SrcPort: 5555, + DstPort: 6666, + SeqNum: 7777, + AckNum: 8888, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 50000, + Checksum: 0, // Conntrack doesn't verify the checksum. + }) + ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) + ipHdr.Encode(&header.IPv4Fields{ + TotalLength: packetLen, + Protocol: uint8(header.TCPProtocolNumber), + SrcAddr: testutil.MustParse4("1.0.0.1"), + DstAddr: testutil.MustParse4("1.0.0.2"), + Checksum: 0, // Conntrack doesn't verify the checksum. + }) + + return pkt +} + +// checkNumTuples checks that there are exactly want tuples tracked by +// conntrack. +func (ct *ConnTrack) checkNumTuples(t *testing.T, want int) { + t.Helper() + ct.mu.RLock() + defer ct.mu.RUnlock() + + var total int + for idx := range ct.buckets { + ct.buckets[idx].mu.RLock() + total += ct.buckets[idx].tuples.Len() + ct.buckets[idx].mu.RUnlock() + } + + if total != want { + t.Fatalf("checkNumTuples: got %d, wanted %d", total, want) + } +} + +func (ct *ConnTrack) reapEverything() { + var bucket int + for { + newBucket, _ := ct.reapUnused(bucket, 0 /* ignored */) + // We started reaping at bucket 0. If the next bucket isn't after our + // current bucket, we've gone through them all. + if newBucket <= bucket { + break + } + bucket = newBucket + } +} diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index ccb69393b..c2f1f4798 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -181,10 +181,6 @@ func (*fwdTestNetworkProtocol) MinimumPacketSize() int { return fwdTestNetHeaderLen } -func (*fwdTestNetworkProtocol) DefaultPrefixLen() int { - return fwdTestNetDefaultPrefixLen -} - func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } @@ -384,8 +380,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC #1 failed:", err) } - if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress #1 failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fwdTestNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fwdTestNetDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } // NIC 2 has the link address "b", and added the network address 2. @@ -397,8 +400,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC #2 failed:", err) } - if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress #2 failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fwdTestNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fwdTestNetDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr2, AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err) } nic, ok := s.nics[2] diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go index 3a20839da..99e5d2df7 100644 --- a/pkg/tcpip/stack/icmp_rate_limit.go +++ b/pkg/tcpip/stack/icmp_rate_limit.go @@ -16,6 +16,7 @@ package stack import ( "golang.org/x/time/rate" + "gvisor.dev/gvisor/pkg/tcpip" ) const ( @@ -31,11 +32,41 @@ const ( // ICMPRateLimiter is a global rate limiter that controls the generation of // ICMP messages generated by the stack. type ICMPRateLimiter struct { - *rate.Limiter + limiter *rate.Limiter + clock tcpip.Clock } // NewICMPRateLimiter returns a global rate limiter for controlling the rate -// at which ICMP messages are generated by the stack. -func NewICMPRateLimiter() *ICMPRateLimiter { - return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)} +// at which ICMP messages are generated by the stack. The returned limiter +// does not apply limits to any ICMP types by default. +func NewICMPRateLimiter(clock tcpip.Clock) *ICMPRateLimiter { + return &ICMPRateLimiter{ + clock: clock, + limiter: rate.NewLimiter(icmpLimit, icmpBurst), + } +} + +// SetLimit sets a new Limit for the limiter. +func (l *ICMPRateLimiter) SetLimit(limit rate.Limit) { + l.limiter.SetLimitAt(l.clock.Now(), limit) +} + +// Limit returns the maximum overall event rate. +func (l *ICMPRateLimiter) Limit() rate.Limit { + return l.limiter.Limit() +} + +// SetBurst sets a new burst size for the limiter. +func (l *ICMPRateLimiter) SetBurst(burst int) { + l.limiter.SetBurstAt(l.clock.Now(), burst) +} + +// Burst returns the maximum burst size. +func (l *ICMPRateLimiter) Burst() int { + return l.limiter.Burst() +} + +// Allow reports whether one ICMP message may be sent now. +func (l *ICMPRateLimiter) Allow() bool { + return l.limiter.AllowN(l.clock.Now(), 1) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index f152c0d83..fd61387bf 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -42,7 +42,7 @@ const reaperDelay = 5 * time.Second // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. -func DefaultTables(seed uint32) *IPTables { +func DefaultTables(seed uint32, clock tcpip.Clock) *IPTables { return &IPTables{ v4Tables: [NumTables]Table{ NATID: { @@ -182,7 +182,8 @@ func DefaultTables(seed uint32) *IPTables { Postrouting: {MangleID, NATID}, }, connections: ConnTrack{ - seed: seed, + seed: seed, + clock: clock, }, reaperDone: make(chan struct{}, 1), } @@ -264,33 +265,125 @@ const ( chainReturn ) -// Check runs pkt through the rules for hook. It returns true when the packet -// should continue traversing the network stack and false when it should be -// dropped. +// CheckPrerouting performs the prerouting hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndpoint, inNicName string) bool { + const hook = Prerouting + + if it.shouldSkip(pkt.NetworkProtocolNumber) { + return true + } + + pkt.tuple = it.connections.getConnOrMaybeInsertNoop(pkt) + + return it.check(hook, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */) +} + +// CheckInput performs the input hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool { + const hook = Input + + if it.shouldSkip(pkt.NetworkProtocolNumber) { + return true + } + + ret := it.check(hook, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */) + if t := pkt.tuple; t != nil { + t.conn.finalize() + } + pkt.tuple = nil + return ret +} + +// CheckForward performs the forward hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string) bool { + if it.shouldSkip(pkt.NetworkProtocolNumber) { + return true + } + return it.check(Forward, pkt, nil /* route */, nil /* addressEP */, inNicName, outNicName) +} + +// CheckOutput performs the output hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) bool { + const hook = Output + + if it.shouldSkip(pkt.NetworkProtocolNumber) { + return true + } + + pkt.tuple = it.connections.getConnOrMaybeInsertNoop(pkt) + + return it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName) +} + +// CheckPostrouting performs the postrouting hook on the packet. // -// Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { - if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, outNicName string) bool { + const hook = Postrouting + + if it.shouldSkip(pkt.NetworkProtocolNumber) { return true } + + ret := it.check(hook, pkt, r, addressEP, "" /* inNicName */, outNicName) + if t := pkt.tuple; t != nil { + t.conn.finalize() + } + pkt.tuple = nil + return ret +} + +func (it *IPTables) shouldSkip(netProto tcpip.NetworkProtocolNumber) bool { + switch netProto { + case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber: + default: + // IPTables only supports IPv4/IPv6. + return true + } + + it.mu.RLock() + defer it.mu.RUnlock() // Many users never configure iptables. Spare them the cost of rule // traversal if rules have never been set. + return !it.modified +} + +// check runs pkt through the rules for hook. It returns true when the packet +// should continue traversing the network stack and false when it should be +// dropped. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool { it.mu.RLock() defer it.mu.RUnlock() - if !it.modified { - return true - } - - // Packets are manipulated only if connection and matching - // NAT rule exists. - shouldTrack := it.connections.handlePacket(pkt, hook, r) // Go through each table containing the hook. priorities := it.priorities[hook] for _, tableID := range priorities { - // If handlePacket already NATed the packet, we don't need to - // check the NAT table. - if tableID == NATID && pkt.NatDone { + if t := pkt.tuple; t != nil && tableID == NATID && t.conn.handlePacket(pkt, hook, r) { continue } var table Table @@ -300,7 +393,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr table = it.v4Tables[tableID] } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -311,7 +404,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr // Any Return from a built-in chain means we have to // call the underflow. underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, preroutingAddr); v { + switch v, _ := underflow.Target.Action(pkt, hook, r, addressEP); v { case RuleAccept: continue case RuleDrop: @@ -327,21 +420,6 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr } } - // If this connection should be tracked, try to add an entry for it. If - // traversing the nat table didn't end in adding an entry, - // maybeInsertNoop will add a no-op entry for the connection. This is - // needeed when establishing connections so that the SYN/ACK reply to an - // outgoing SYN is delivered to the correct endpoint rather than being - // redirected by a prerouting rule. - // - // From the iptables documentation: "If there is no rule, a `null' - // binding is created: this usually does not map the packet, but exists - // to ensure we don't map another stream over an existing one." - if shouldTrack { - it.connections.maybeInsertNoop(pkt, hook) - } - - // Every table returned Accept. return true } @@ -375,30 +453,46 @@ func (it *IPTables) startReaper(interval time.Duration) { }() } -// CheckPackets runs pkts through the rules for hook and returns a map of packets that -// should not go forward. +// CheckOutputPackets performs the output hook on the packets. // -// Preconditions: -// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// * pkt.NetworkHeader is not nil. +// Returns a map of packets that must be dropped. // -// NOTE: unlike the Check API the returned map contains packets that should be -// dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +// Precondition: The packets' network and transport header must be set. +func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { + return checkPackets(pkts, func(pkt *PacketBuffer) bool { + return it.CheckOutput(pkt, r, outNicName) + }, true /* dnat */) +} + +// CheckPostroutingPackets performs the postrouting hook on the packets. +// +// Returns a map of packets that must be dropped. +// +// Precondition: The packets' network and transport header must be set. +func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, addressEP AddressableEndpoint, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { + return checkPackets(pkts, func(pkt *PacketBuffer) bool { + return it.CheckPostrouting(pkt, r, addressEP, outNicName) + }, false /* dnat */) +} + +func checkPackets(pkts PacketBufferList, f func(*PacketBuffer) bool, dnat bool) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if !pkt.NatDone { - if ok := it.Check(hook, pkt, r, "", inNicName, outNicName); !ok { - if drop == nil { - drop = make(map[*PacketBuffer]struct{}) - } - drop[pkt] = struct{}{} + natDone := &pkt.SNATDone + if dnat { + natDone = &pkt.DNATDone + } + + if ok := f(pkt); !ok { + if drop == nil { + drop = make(map[*PacketBuffer]struct{}) } - if pkt.NatDone { - if natPkts == nil { - natPkts = make(map[*PacketBuffer]struct{}) - } - natPkts[pkt] = struct{}{} + drop[pkt] = struct{}{} + } + if *natDone { + if natPkts == nil { + natPkts = make(map[*PacketBuffer]struct{}) } + natPkts[pkt] = struct{}{} } } return drop, natPkts @@ -407,11 +501,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inN // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict { case RuleAccept: return chainAccept @@ -428,7 +522,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, addressEP, inNicName, outNicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -454,7 +548,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. @@ -477,16 +571,16 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt, &it.connections, hook, r, preroutingAddr) + return rule.Target.Action(pkt, hook, r, addressEP) } // OriginalDst returns the original destination of redirected connections. It // returns an error if the connection doesn't exist or isn't redirected. -func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { +func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { it.mu.RLock() defer it.mu.RUnlock() if !it.modified { return "", 0, &tcpip.ErrNotConnected{} } - return it.connections.originalDst(epID, netProto) + return it.connections.originalDst(epID, netProto, transProto) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 96cc899bb..ef515bdd2 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -29,7 +29,7 @@ type AcceptTarget struct { } // Action implements Target.Action. -func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*AcceptTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleAccept, 0 } @@ -40,7 +40,7 @@ type DropTarget struct { } // Action implements Target.Action. -func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*DropTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleDrop, 0 } @@ -52,7 +52,7 @@ type ErrorTarget struct { } // Action implements Target.Action. -func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ErrorTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -67,7 +67,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*UserChainTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -79,10 +79,49 @@ type ReturnTarget struct { } // Action implements Target.Action. -func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ReturnTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleReturn, 0 } +// DNATTarget modifies the destination port/IP of packets. +type DNATTarget struct { + // The new destination address for packets. + // + // Immutable. + Addr tcpip.Address + + // The new destination port for packets. + // + // Immutable. + Port uint16 + + // NetworkProtocol is the network protocol the target is used with. + // + // Immutable. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// Action implements Target.Action. +func (rt *DNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { + // Sanity check. + if rt.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "DNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + rt.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + + switch hook { + case Prerouting, Output: + case Input, Forward, Postrouting: + panic(fmt.Sprintf("%s not supported for DNAT", hook)) + default: + panic(fmt.Sprintf("%s unrecognized", hook)) + } + + return natAction(pkt, hook, r, rt.Port, rt.Addr, true /* dnat */) + +} + // RedirectTarget redirects the packet to this machine by modifying the // destination port/IP. Outgoing packets are redirected to the loopback device, // and incoming packets are redirected to the incoming interface (rather than @@ -97,7 +136,7 @@ type RedirectTarget struct { } // Action implements Target.Action. -func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt *RedirectTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { // Sanity check. if rt.NetworkProtocol != pkt.NetworkProtocolNumber { panic(fmt.Sprintf( @@ -105,18 +144,9 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r rt.NetworkProtocol, pkt.NetworkProtocolNumber)) } - // Packet is already manipulated. - if pkt.NatDone { - return RuleAccept, 0 - } - - // Drop the packet if network and transport header are not set. - if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() { - return RuleDrop, 0 - } - // Change the address to loopback (127.0.0.1 or ::1) in Output and to // the primary address of the incoming interface in Prerouting. + var address tcpip.Address switch hook { case Output: if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { @@ -125,48 +155,13 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r address = header.IPv6Loopback } case Prerouting: - // No-op, as address is already set correctly. + // addressEP is expected to be set for the prerouting hook. + address = addressEP.MainAddress().Address default: panic("redirect target is supported only on output and prerouting hooks") } - switch protocol := pkt.TransportProtocolNumber; protocol { - case header.UDPProtocolNumber: - udpHeader := header.UDP(pkt.TransportHeader().View()) - - if hook == Output { - // Only calculate the checksum if offloading isn't supported. - requiresChecksum := r.RequiresTXTransportChecksum() - rewritePacket( - pkt.Network(), - udpHeader, - false, /* updateSRCFields */ - requiresChecksum, - requiresChecksum, - rt.Port, - address, - ) - } else { - udpHeader.SetDestinationPort(rt.Port) - } - - pkt.NatDone = true - case header.TCPProtocolNumber: - if ct == nil { - return RuleAccept, 0 - } - - // Set up conection for matching NAT rule. Only the first - // packet of the connection comes here. Other packets will be - // manipulated in connection tracking. - if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil { - ct.handlePacket(pkt, hook, r) - } - default: - return RuleDrop, 0 - } - - return RuleAccept, 0 + return natAction(pkt, hook, r, rt.Port, address, true /* dnat */) } // SNATTarget modifies the source port/IP in the outgoing packets. @@ -179,8 +174,36 @@ type SNATTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } +func natAction(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) (RuleVerdict, int) { + // Drop the packet if network and transport header are not set. + if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() { + return RuleDrop, 0 + } + + t := pkt.tuple + if t == nil { + return RuleDrop, 0 + } + + // TODO(https://gvisor.dev/issue/5773): If the port is in use, pick a + // different port. + if port == 0 { + switch protocol := pkt.TransportProtocolNumber; protocol { + case header.UDPProtocolNumber: + port = header.UDP(pkt.TransportHeader().View()).SourcePort() + case header.TCPProtocolNumber: + port = header.TCP(pkt.TransportHeader().View()).SourcePort() + default: + panic(fmt.Sprintf("unsupported transport protocol = %d", pkt.TransportProtocolNumber)) + } + } + + t.conn.performNAT(pkt, hook, r, port, address, dnat) + return RuleAccept, 0 +} + // Action implements Target.Action. -func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (st *SNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) { // Sanity check. if st.NetworkProtocol != pkt.NetworkProtocolNumber { panic(fmt.Sprintf( @@ -188,16 +211,6 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou st.NetworkProtocol, pkt.NetworkProtocolNumber)) } - // Packet is already manipulated. - if pkt.NatDone { - return RuleAccept, 0 - } - - // Drop the packet if network and transport header are not set. - if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() { - return RuleDrop, 0 - } - switch hook { case Postrouting, Input: case Prerouting, Output, Forward: @@ -206,37 +219,43 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou panic(fmt.Sprintf("%s unrecognized", hook)) } - switch protocol := pkt.TransportProtocolNumber; protocol { - case header.UDPProtocolNumber: - // Only calculate the checksum if offloading isn't supported. - requiresChecksum := r.RequiresTXTransportChecksum() - rewritePacket( - pkt.Network(), - header.UDP(pkt.TransportHeader().View()), - true, /* updateSRCFields */ - requiresChecksum, - requiresChecksum, - st.Port, - st.Addr, - ) - - pkt.NatDone = true - case header.TCPProtocolNumber: - if ct == nil { - return RuleAccept, 0 - } + return natAction(pkt, hook, r, st.Port, st.Addr, false /* dnat */) +} - // Set up conection for matching NAT rule. Only the first - // packet of the connection comes here. Other packets will be - // manipulated in connection tracking. - if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil { - ct.handlePacket(pkt, hook, r) - } +// MasqueradeTarget modifies the source port/IP in the outgoing packets. +type MasqueradeTarget struct { + // NetworkProtocol is the network protocol the target is used with. It + // is immutable. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// Action implements Target.Action. +func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { + // Sanity check. + if mt.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "MasqueradeTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + mt.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + + switch hook { + case Postrouting: + case Prerouting, Input, Forward, Output: + panic(fmt.Sprintf("masquerade target is supported only on postrouting hook; hook = %d", hook)) default: + panic(fmt.Sprintf("%s unrecognized", hook)) + } + + // addressEP is expected to be set for the postrouting hook. + ep := addressEP.AcquireOutgoingPrimaryAddress(pkt.Network().DestinationAddress(), false /* allowExpired */) + if ep == nil { + // No address exists that we can use as a source address. return RuleDrop, 0 } - return RuleAccept, 0 + address := ep.AddressWithPrefix().Address + ep.DecRef() + return natAction(pkt, hook, r, 0 /* port */, address, false /* dnat */) } func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) { diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 66e5f22ac..b22024667 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -81,17 +81,6 @@ const ( // // +stateify savable type IPTables struct { - // mu protects v4Tables, v6Tables, and modified. - mu sync.RWMutex - // v4Tables and v6tables map tableIDs to tables. They hold builtin - // tables only, not user tables. mu must be locked for accessing. - v4Tables [NumTables]Table - v6Tables [NumTables]Table - // modified is whether tables have been modified at least once. It is - // used to elide the iptables performance overhead for workloads that - // don't utilize iptables. - modified bool - // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that // hook. It is immutable. @@ -101,6 +90,21 @@ type IPTables struct { // reaperDone can be signaled to stop the reaper goroutine. reaperDone chan struct{} + + mu sync.RWMutex + // v4Tables and v6tables map tableIDs to tables. They hold builtin + // tables only, not user tables. + // + // +checklocks:mu + v4Tables [NumTables]Table + // +checklocks:mu + v6Tables [NumTables]Table + // modified is whether tables have been modified at least once. It is + // used to elide the iptables performance overhead for workloads that + // don't utilize iptables. + // + // +checklocks:mu + modified bool } // VisitTargets traverses all the targets of all tables and replaces each with @@ -352,5 +356,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. - Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) + Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 4d5431da1..40b33b6b5 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -333,8 +333,12 @@ func TestDADDisabled(t *testing.T) { Address: addr1, PrefixLen: defaultPrefixLen, } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err) } // Should get the address immediately since we should not have performed @@ -379,12 +383,15 @@ func TestDADResolveLoopback(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: addr1, - PrefixLen: defaultPrefixLen, + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: defaultPrefixLen, + }, } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err) } // Address should not be considered bound to the NIC yet (DAD ongoing). @@ -517,8 +524,12 @@ func TestDADResolve(t *testing.T) { Address: addr1, PrefixLen: defaultPrefixLen, } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err) } // Make sure the address does not resolve before the resolution time has @@ -740,8 +751,12 @@ func TestDADFail(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr1.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // Address should not be considered bound to the NIC yet @@ -778,8 +793,8 @@ func TestDADFail(t *testing.T) { // Attempting to add the address again should not fail if the address's // state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } }) } @@ -851,8 +866,12 @@ func TestDADStop(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr1.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // Address should not be considered bound to the NIC yet (DAD ongoing). @@ -975,17 +994,29 @@ func TestSetNDPConfigurations(t *testing.T) { // Add addresses for each NIC. addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix1, + } + if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID1, protocolAddr1, err) } addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix2, + } + if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID2, protocolAddr2, err) } expectDADEvent(nicID2, addr2) addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix3, + } + if err := s.AddProtocolAddress(nicID3, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID3, protocolAddr3, err) } expectDADEvent(nicID3, addr3) @@ -2788,8 +2819,12 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { continue } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: test.addrs[j].Address.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } manuallyAssignedAddresses[test.addrs[j].Address] = struct{}{} @@ -3644,8 +3679,9 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr2, } - if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) + properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} + if err := s.AddProtocolAddress(nicID, protoAddr2, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v) = %s", nicID, protoAddr2, properties, err) } // addr2 should be more preferred now since it is at the front of the primary // list. @@ -3733,8 +3769,9 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { } // Add the address as a static address before SLAAC tries to add it. - if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err) + protocolAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr} + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}) = %s", protocolAddr, err) } if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) @@ -4073,8 +4110,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // Attempting to add the address manually should not fail if the // address's state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if err := s.RemoveAddress(nicID, addr.Address); err != nil { t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) @@ -5362,8 +5403,12 @@ func TestRouterSolicitation(t *testing.T) { } if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index a796942ab..e251e3b24 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -97,6 +97,8 @@ type packetEndpointList struct { mu sync.RWMutex // eps is protected by mu, but the contained PacketEndpoint values are not. + // + // +checklocks:mu eps []PacketEndpoint } @@ -117,6 +119,12 @@ func (p *packetEndpointList) remove(ep PacketEndpoint) { } } +func (p *packetEndpointList) len() int { + p.mu.RLock() + defer p.mu.RUnlock() + return len(p.eps) +} + // forEach calls fn with each endpoints in p while holding the read lock on p. func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) { p.mu.RLock() @@ -157,14 +165,8 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC resolutionRequired := ep.Capabilities()&CapabilityResolutionRequired != 0 - // Register supported packet and network endpoint protocols. - for _, netProto := range header.Ethertypes { - nic.packetEPs.eps[netProto] = new(packetEndpointList) - } for _, netProto := range stack.networkProtocols { netNum := netProto.Number() - nic.packetEPs.eps[netNum] = new(packetEndpointList) - netEP := netProto.NewEndpoint(nic, nic) nic.networkEndpoints[netNum] = netEP @@ -514,7 +516,7 @@ func (n *nic) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, // addAddress adds a new address to n, so that it starts accepting packets // targeted at the given address (and network protocol). -func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { +func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error { ep, ok := n.networkEndpoints[protocolAddress.Protocol] if !ok { return &tcpip.ErrUnknownProtocol{} @@ -525,7 +527,7 @@ func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo return &tcpip.ErrNotSupported{} } - addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) + addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, properties) if err == nil { // We have no need for the address endpoint. addressEndpoint.DecRef() @@ -831,24 +833,9 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt transProto := state.proto - // TransportHeader is empty only when pkt is an ICMP packet or was reassembled - // from fragments. if pkt.TransportHeader().View().IsEmpty() { - // ICMP packets don't have their TransportHeader fields set yet, parse it - // here. See icmp/protocol.go:protocol.Parse for a full explanation. - if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { - // ICMP packets may be longer, but until icmp.Parse is implemented, here - // we parse it using the minimum size. - if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { - n.stats.malformedL4RcvdPackets.Increment() - // We consider a malformed transport packet handled because there is - // nothing the caller can do. - return TransportPacketHandled - } - } else if !transProto.Parse(pkt) { - n.stats.malformedL4RcvdPackets.Increment() - return TransportPacketHandled - } + n.stats.malformedL4RcvdPackets.Increment() + return TransportPacketHandled } srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) @@ -974,7 +961,8 @@ func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa eps, ok := n.packetEPs.eps[netProto] if !ok { - return &tcpip.ErrNotSupported{} + eps = new(packetEndpointList) + n.packetEPs.eps[netProto] = eps } eps.add(ep) @@ -990,6 +978,9 @@ func (n *nic) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep return } eps.remove(ep) + if eps.len() == 0 { + delete(n.packetEPs.eps, netProto) + } } // isValidForOutgoing returns true if the endpoint can be used to send out a diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 5cb342f78..c8ad93f29 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -127,11 +127,6 @@ func (*testIPv6Protocol) MinimumPacketSize() int { return header.IPv6MinimumSize } -// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen. -func (*testIPv6Protocol) DefaultPrefixLen() int { - return header.IPv6AddressSize * 8 -} - // ParseAddresses implements NetworkProtocol.ParseAddresses. func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv6(v) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 29c22bfd4..c4a4bbd22 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -126,9 +126,13 @@ type PacketBuffer struct { EgressRoute RouteInfo GSOOptions GSO - // NatDone indicates if the packet has been manipulated as per NAT - // iptables rule. - NatDone bool + // SNATDone indicates if the packet's source has been manipulated as per + // iptables NAT table. + SNATDone bool + + // DNATDone indicates if the packet's destination has been manipulated as per + // iptables NAT table. + DNATDone bool // PktType indicates the SockAddrLink.PacketType of the packet as defined in // https://www.man7.org/linux/man-pages/man7/packet.7.html. @@ -143,6 +147,8 @@ type PacketBuffer struct { // NetworkPacketInfo holds an incoming packet's network-layer information. NetworkPacketInfo NetworkPacketInfo + + tuple *tuple } // NewPacketBuffer creates a new PacketBuffer with opts. @@ -296,12 +302,14 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { Owner: pk.Owner, GSOOptions: pk.GSOOptions, NetworkProtocolNumber: pk.NetworkProtocolNumber, - NatDone: pk.NatDone, + DNATDone: pk.DNATDone, + SNATDone: pk.SNATDone, TransportProtocolNumber: pk.TransportProtocolNumber, PktType: pk.PktType, NICID: pk.NICID, RXTransportChecksumValidated: pk.RXTransportChecksumValidated, NetworkPacketInfo: pk.NetworkPacketInfo, + tuple: pk.tuple, } } @@ -329,15 +337,41 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { buf: pk.buf.Clone(), // Treat unfilled header portion as reserved. reserved: pk.AvailableHeaderBytes(), + tuple: pk.tuple, + } + return newPk +} + +// DeepCopyForForwarding creates a deep copy of the packet buffer for +// forwarding. +// +// The returned packet buffer will have the network and transport headers +// set if the original packet buffer did. +func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBuffer { + newPk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: reservedHeaderBytes, + Data: PayloadSince(pk.NetworkHeader()).ToVectorisedView(), + IsForwardedPacket: true, + }) + + { + consumeBytes := pk.NetworkHeader().View().Size() + if _, consumed := newPk.NetworkHeader().Consume(consumeBytes); !consumed { + panic(fmt.Sprintf("expected to consume network header %d bytes from new packet", consumeBytes)) + } + newPk.NetworkProtocolNumber = pk.NetworkProtocolNumber } - // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to - // maintain this flag in the packet. Currently conntrack needs this flag to - // tell if a noop connection should be inserted at Input hook. Once conntrack - // redefines the manipulation field as mutable, we won't need the special noop - // connection. - if pk.NatDone { - newPk.NatDone = true + + { + consumeBytes := pk.TransportHeader().View().Size() + if _, consumed := newPk.TransportHeader().Consume(consumeBytes); !consumed { + panic(fmt.Sprintf("expected to consume transport header %d bytes from new packet", consumeBytes)) + } + newPk.TransportProtocolNumber = pk.TransportProtocolNumber } + + newPk.tuple = pk.tuple + return newPk } @@ -389,13 +423,14 @@ func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) { return d.pk.buf.PullUp(d.pk.dataOffset(), size) } -// DeleteFront removes count from the beginning of d. It panics if count > -// d.Size(). All backing storage references after the front of the d are -// invalidated. -func (d PacketData) DeleteFront(count int) { - if !d.pk.buf.Remove(d.pk.dataOffset(), count) { - panic("count > d.Size()") +// Consume is the same as PullUp except that is additionally consumes the +// returned bytes. Subsequent PullUp or Consume will not return these bytes. +func (d PacketData) Consume(size int) (tcpipbuffer.View, bool) { + v, ok := d.PullUp(size) + if ok { + d.pk.consumed += size } + return v, ok } // CapLength reduces d to at most length bytes. diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index 87b023445..c376ed1a1 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -123,32 +123,6 @@ func TestPacketHeaderPush(t *testing.T) { } } -func TestPacketBufferClone(t *testing.T) { - data := concatViews(makeView(20), makeView(30), makeView(40)) - pk := NewPacketBuffer(PacketBufferOptions{ - // Make a copy of data to make sure our truth data won't be taint by - // PacketBuffer. - Data: buffer.NewViewFromBytes(data).ToVectorisedView(), - }) - - bytesToDelete := 30 - originalSize := data.Size() - - clonedPks := []*PacketBuffer{ - pk.Clone(), - pk.CloneToInbound(), - } - pk.Data().DeleteFront(bytesToDelete) - if got, want := pk.Data().Size(), originalSize-bytesToDelete; got != want { - t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got) - } - for _, clonedPk := range clonedPks { - if got := clonedPk.Data().Size(); got != originalSize { - t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got) - } - } -} - func TestPacketHeaderConsume(t *testing.T) { for _, test := range []struct { name string @@ -461,11 +435,17 @@ func TestPacketBufferData(t *testing.T) { } }) - // DeleteFront + // Consume. for _, n := range []int{1, len(tc.data)} { - t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) { + t.Run(fmt.Sprintf("Consume%d", n), func(t *testing.T) { pkt := tc.makePkt(t) - pkt.Data().DeleteFront(n) + v, ok := pkt.Data().Consume(n) + if !ok { + t.Fatalf("Consume failed") + } + if want := []byte(tc.data)[:n]; !bytes.Equal(v, want) { + t.Fatalf("pkt.Data().Consume(n) = 0x%x, want 0x%x", v, want) + } checkData(t, pkt, []byte(tc.data)[n:]) }) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 113baaaae..31b3a554d 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -318,8 +318,7 @@ type PrimaryEndpointBehavior int const ( // CanBePrimaryEndpoint indicates the endpoint can be used as a primary - // endpoint for new connections with no local address. This is the - // default when calling NIC.AddAddress. + // endpoint for new connections with no local address. CanBePrimaryEndpoint PrimaryEndpointBehavior = iota // FirstPrimaryEndpoint indicates the endpoint should be the first @@ -332,6 +331,19 @@ const ( NeverPrimaryEndpoint ) +func (peb PrimaryEndpointBehavior) String() string { + switch peb { + case CanBePrimaryEndpoint: + return "CanBePrimaryEndpoint" + case FirstPrimaryEndpoint: + return "FirstPrimaryEndpoint" + case NeverPrimaryEndpoint: + return "NeverPrimaryEndpoint" + default: + panic(fmt.Sprintf("unknown primary endpoint behavior: %d", peb)) + } +} + // AddressConfigType is the method used to add an address. type AddressConfigType int @@ -351,6 +363,14 @@ const ( AddressConfigSlaacTemp ) +// AddressProperties contains additional properties that can be configured when +// adding an address. +type AddressProperties struct { + PEB PrimaryEndpointBehavior + ConfigType AddressConfigType + Deprecated bool +} + // AssignableAddressEndpoint is a reference counted address endpoint that may be // assigned to a NetworkEndpoint. type AssignableAddressEndpoint interface { @@ -457,7 +477,7 @@ type AddressableEndpoint interface { // Returns *tcpip.ErrDuplicateAddress if the address exists. // // Acquires and returns the AddressEndpoint for the added address. - AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) + AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error) // RemovePermanentAddress removes the passed address if it is a permanent // address. @@ -685,9 +705,6 @@ type NetworkProtocol interface { // than this targeted at this protocol. MinimumPacketSize() int - // DefaultPrefixLen returns the protocol's default prefix length. - DefaultPrefixLen() int - // ParseAddresses returns the source and destination addresses stored in a // packet of this protocol. ParseAddresses(v buffer.View) (src, dst tcpip.Address) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index cb741e540..a05fd7036 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -238,7 +238,7 @@ type Options struct { // DefaultIPTables is an optional iptables rules constructor that is called // if IPTables is nil. If both fields are nil, iptables will allow all // traffic. - DefaultIPTables func(uint32) *IPTables + DefaultIPTables func(seed uint32, clock tcpip.Clock) *IPTables // SecureRNG is a cryptographically secure random number generator. SecureRNG io.Reader @@ -358,7 +358,7 @@ func New(opts Options) *Stack { if opts.DefaultIPTables == nil { opts.DefaultIPTables = DefaultTables } - opts.IPTables = opts.DefaultIPTables(seed) + opts.IPTables = opts.DefaultIPTables(seed, clock) } opts.NUDConfigs.resetInvalidFields() @@ -375,7 +375,7 @@ func New(opts Options) *Stack { stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, tables: opts.IPTables, - icmpRateLimiter: NewICMPRateLimiter(), + icmpRateLimiter: NewICMPRateLimiter(clock), seed: seed, nudConfigs: opts.NUDConfigs, uniqueIDGenerator: opts.UniqueID, @@ -916,46 +916,9 @@ type NICStateFlags struct { Loopback bool } -// AddAddress adds a new network-layer address to the specified NIC. -func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { - return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint) -} - -// AddAddressWithPrefix is the same as AddAddress, but allows you to specify -// the address prefix. -func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) tcpip.Error { - ap := tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: addr, - } - return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint) -} - -// AddProtocolAddress adds a new network-layer protocol address to the -// specified NIC. -func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) tcpip.Error { - return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint) -} - -// AddAddressWithOptions is the same as AddAddress, but allows you to specify -// whether the new endpoint can be primary or not. -func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) tcpip.Error { - netProto, ok := s.networkProtocols[protocol] - if !ok { - return &tcpip.ErrUnknownProtocol{} - } - return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, peb) -} - -// AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows -// you to specify whether the new endpoint can be primary or not. -func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { +// AddProtocolAddress adds an address to the specified NIC, possibly with extra +// properties. +func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() @@ -964,7 +927,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc return &tcpip.ErrUnknownNICID{} } - return nic.addAddress(protocolAddress, peb) + return nic.addAddress(protocolAddress, properties) } // RemoveAddress removes an existing network-layer address from the specified @@ -1902,12 +1865,6 @@ const ( // ParsePacketBufferTransport parses the provided packet buffer's transport // header. func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult { - // ICMP packets don't have their TransportHeader fields set yet, parse it - // here. See icmp/protocol.go:protocol.Parse for a full explanation. - if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { - return ParsedOK - } - pkt.TransportProtocolNumber = protocol // Parse the transport header if present. state, ok := s.transportProtocols[protocol] diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 3089c0ef4..f5a35eac4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -139,18 +139,15 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen) + hdr, ok := pkt.Data().Consume(fakeNetHeaderLen) if !ok { return } - // DeleteFront invalidates slices. Make a copy before trimming. - nb := append([]byte(nil), hdr...) - pkt.Data().DeleteFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportError( - tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), - tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), + tcpip.Address(hdr[srcAddrOffset:srcAddrOffset+1]), + tcpip.Address(hdr[dstAddrOffset:dstAddrOffset+1]), fakeNetNumber, - tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), + tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), // Nothing checks the error. nil, /* transport error */ pkt, @@ -158,8 +155,18 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { return } + transProtoNum := tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]) + switch err := f.proto.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { + case stack.ParsedOK: + case stack.UnknownTransportProtocol, stack.TransportLayerParseError: + // The transport layer will handle unknown protocols and transport layer + // parsing errors. + default: + panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) + } + // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(transProtoNum, pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -221,6 +228,8 @@ func (*fakeNetworkEndpointStats) IsNetworkEndpointStats() {} // number of packets sent and received via endpoints of this protocol. The index // where packets are added is given by the packet's destination address MOD 10. type fakeNetworkProtocol struct { + stack *stack.Stack + packetCount [10]int sendPacketCount [10]int defaultTTL uint8 @@ -234,10 +243,6 @@ func (*fakeNetworkProtocol) MinimumPacketSize() int { return fakeNetHeaderLen } -func (*fakeNetworkProtocol) DefaultPrefixLen() int { - return fakeDefaultPrefixLen -} - func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { return f.packetCount[int(intfAddr)%len(f.packetCount)] } @@ -306,8 +311,8 @@ func (f *fakeNetworkEndpoint) SetForwarding(v bool) { f.mu.forwarding = v } -func fakeNetFactory(*stack.Stack) stack.NetworkProtocol { - return &fakeNetworkProtocol{} +func fakeNetFactory(s *stack.Stack) stack.NetworkProtocol { + return &fakeNetworkProtocol{stack: s} } // linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify @@ -349,12 +354,26 @@ func TestNetworkReceive(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } - if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr2, err) } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) @@ -517,8 +536,15 @@ func TestNetworkSend(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Make sure that the link-layer endpoint received the outbound packet. @@ -538,12 +564,26 @@ func TestNetworkSendMultiRoute(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x03", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err) } ep2 := channel.New(10, defaultMTU, "") @@ -551,12 +591,26 @@ func TestNetworkSendMultiRoute(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err) } - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr4 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x04", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err) } // Set a route table that sends all packets with odd destination @@ -812,8 +866,15 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err) } ep2 := channel.New(1, defaultMTU, "") @@ -821,8 +882,15 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) } - if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr2, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err) } // Set a route table that sends all packets with odd destination @@ -978,12 +1046,26 @@ func TestRoutes(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x03", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err) } ep2 := channel.New(10, defaultMTU, "") @@ -991,12 +1073,26 @@ func TestRoutes(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err) } - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr4 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x04", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err) } // Set a route table that sends all packets with odd destination @@ -1058,8 +1154,15 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { subnet, err := tcpip.NewSubnet("\x00", "\x00") @@ -1108,8 +1211,15 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { subnet, err := tcpip.NewSubnet("\x00", "\x00") @@ -1242,8 +1352,15 @@ func TestEndpointExpiration(t *testing.T) { // 2. Add Address, everything should work. //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) @@ -1270,8 +1387,8 @@ func TestEndpointExpiration(t *testing.T) { // 4. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) @@ -1310,8 +1427,8 @@ func TestEndpointExpiration(t *testing.T) { // 7. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) @@ -1453,8 +1570,15 @@ func TestExternalSendWithHandleLocal(t *testing.T) { if err := s.CreateNIC(nicID, ep); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}}) @@ -1510,8 +1634,15 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { @@ -1633,8 +1764,8 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}} - if err := s.AddProtocolAddress(1, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) + if err := s.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", protoAddr, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { @@ -1678,13 +1809,13 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { t.Fatalf("CreateNIC failed: %s", err) } nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr} - if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) + if err := s.AddProtocolAddress(1, nic1ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", nic1ProtoAddr, err) } nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr} - if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { - t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) + if err := s.AddProtocolAddress(2, nic2ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(2, %+v, {}) failed: %s", nic2ProtoAddr, err) } // Set the initial route table. @@ -1726,7 +1857,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // 2. Case: Having an explicit route for broadcast will select that one. rt = append( []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + {Destination: header.IPv4Broadcast.WithPrefix().Subnet(), NIC: 1}, }, rt..., ) @@ -1808,8 +1939,15 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want) } - if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: anyAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded { @@ -1886,22 +2024,27 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { // Add an address and in case of a primary one include a // prefixLen. address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen)) + properties := stack.AddressProperties{PEB: behavior} if behavior == stack.CanBePrimaryEndpoint { protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: addrLen * 8, - }, + Protocol: fakeNetNumber, + AddressWithPrefix: address.WithPrefix(), } - if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, protocolAddress, behavior, err) + if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err) } // Remember the address/prefix. primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{} } else { - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s:", nicID, fakeNetNumber, address, behavior, err) + protocolAddress := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err) } } } @@ -1996,8 +2139,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { PrefixLen: tc.prefixLen, }, } - if err := s.AddProtocolAddress(1, protocolAddress); err != nil { - t.Fatal("AddProtocolAddress failed:", err) + if err := s.AddProtocolAddress(1, protocolAddress, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", protocolAddress, err) } // Check that we get the right initial address and prefix length. @@ -2047,33 +2190,6 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto } } -func TestAddAddress(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - var addrGen addressGenerator - expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) - for _, addrLen := range []int{4, 16} { - address := addrGen.next(addrLen) - if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil { - t.Fatalf("AddAddress(address=%s) failed: %s", address, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, - }) - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - func TestAddProtocolAddress(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ @@ -2084,96 +2200,43 @@ func TestAddProtocolAddress(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - var addrGen addressGenerator - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)) - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { - t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) - } - expectedAddresses = append(expectedAddresses, protocolAddress) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - addrLenRange := []int{4, 16} behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)) + configTypeRange := []stack.AddressConfigType{stack.AddressConfigStatic, stack.AddressConfigSlaac, stack.AddressConfigSlaacTemp} + deprecatedRange := []bool{false, true} + wantAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)*len(configTypeRange)*len(deprecatedRange)) var addrGen addressGenerator for _, addrLen := range addrLenRange { for _, behavior := range behaviorRange { - address := addrGen.next(addrLen) - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, - }) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddProtocolAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange)) - var addrGen addressGenerator - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - for _, behavior := range behaviorRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) + for _, configType := range configTypeRange { + for _, deprecated := range deprecatedRange { + address := addrGen.next(addrLen) + properties := stack.AddressProperties{ + PEB: behavior, + ConfigType: configType, + Deprecated: deprecated, + } + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v) failed: %s", nicID, protocolAddr, properties, err) + } + wantAddresses = append(wantAddresses, tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, + }) } - expectedAddresses = append(expectedAddresses, protocolAddress) } } } gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) + verifyAddresses(t, wantAddresses, gotAddresses) } func TestCreateNICWithOptions(t *testing.T) { @@ -2290,8 +2353,15 @@ func TestNICStats(t *testing.T) { if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed: ", err) } - if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: nic.addr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicid, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicid, protocolAddr, err) } { @@ -2735,8 +2805,16 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // be returned by a call to GetMainNICAddress; // else, it should. const address1 = tcpip.Address("\x01") - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, pi); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err) + properties := stack.AddressProperties{PEB: pi} + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address1, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr, properties, err) } addr, err := s.GetMainNICAddress(nicID, fakeNetNumber) if err != nil { @@ -2785,16 +2863,31 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // Add some other address with peb set to // FirstPrimaryEndpoint. const address3 = tcpip.Address("\x03") - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint, err) - + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address3, + PrefixLen: fakeDefaultPrefixLen, + }, + } + properties = stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} + if err := s.AddProtocolAddress(nicID, protocolAddr3, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr3, properties, err) } // Add back the address we removed earlier and // make sure the new peb was respected. // (The address should just be promoted now). - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, ps); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address1, + PrefixLen: fakeDefaultPrefixLen, + }, + } + properties = stack.AddressProperties{PEB: ps} + if err := s.AddProtocolAddress(nicID, protocolAddr1, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr1, properties, err) } var primaryAddrs []tcpip.Address for _, pa := range s.NICInfo()[nicID].ProtocolAddresses { @@ -3096,8 +3189,12 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { } for _, a := range test.nicAddrs { - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil { - t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: a.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } @@ -3203,8 +3300,12 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: addr1.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // The NIC should have joined addr1's solicited node multicast address. @@ -3359,8 +3460,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { PrefixLen: 128, }, } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) + if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err) } // Address should be in the list of all addresses. @@ -3687,8 +3788,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { if err := s.CreateNIC(nicID1, ep); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err) } s.SetRouteTable(test.routes) @@ -3750,8 +3851,8 @@ func TestResolveWith(t *testing.T) { PrefixLen: 24, }, } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, addr, err) + if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err) } s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}) @@ -3792,8 +3893,15 @@ func TestRouteReleaseAfterAddrRemoval(t *testing.T) { if err := s.CreateNIC(nicID, ep); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } { subnet, err := tcpip.NewSubnet("\x00", "\x00") @@ -3881,8 +3989,8 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { PrefixLen: 8, }, } - if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err) + if err := s.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddress, err) } // Check that we get the right initial address and prefix length. @@ -3990,44 +4098,44 @@ func TestFindRouteWithForwarding(t *testing.T) { ) type netCfg struct { - proto tcpip.NetworkProtocolNumber - factory stack.NetworkProtocolFactory - nic1Addr tcpip.Address - nic2Addr tcpip.Address - remoteAddr tcpip.Address + proto tcpip.NetworkProtocolNumber + factory stack.NetworkProtocolFactory + nic1AddrWithPrefix tcpip.AddressWithPrefix + nic2AddrWithPrefix tcpip.AddressWithPrefix + remoteAddr tcpip.Address } fakeNetCfg := netCfg{ - proto: fakeNetNumber, - factory: fakeNetFactory, - nic1Addr: nic1Addr, - nic2Addr: nic2Addr, - remoteAddr: remoteAddr, + proto: fakeNetNumber, + factory: fakeNetFactory, + nic1AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic1Addr, PrefixLen: fakeDefaultPrefixLen}, + nic2AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic2Addr, PrefixLen: fakeDefaultPrefixLen}, + remoteAddr: remoteAddr, } globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16()) globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16()) ipv6LinkLocalNIC1WithGlobalRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: llAddr1, - nic2Addr: globalIPv6Addr2, - remoteAddr: globalIPv6Addr1, + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1AddrWithPrefix: llAddr1.WithPrefix(), + nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(), + remoteAddr: globalIPv6Addr1, } ipv6GlobalNIC1WithLinkLocalRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: globalIPv6Addr1, - nic2Addr: llAddr1, - remoteAddr: llAddr2, + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(), + nic2AddrWithPrefix: llAddr1.WithPrefix(), + remoteAddr: llAddr2, } ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: globalIPv6Addr1, - nic2Addr: globalIPv6Addr2, - remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(), + nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(), + remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", } tests := []struct { @@ -4036,8 +4144,8 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg netCfg forwardingEnabled bool - addrNIC tcpip.NICID - localAddr tcpip.Address + addrNIC tcpip.NICID + localAddrWithPrefix tcpip.AddressWithPrefix findRouteErr tcpip.Error dependentOnForwarding bool @@ -4047,7 +4155,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID1, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4056,7 +4164,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID1, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4065,7 +4173,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID1, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4074,7 +4182,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID1, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: true, }, @@ -4083,7 +4191,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID2, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4092,7 +4200,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID2, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4101,7 +4209,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID2, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4110,7 +4218,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID2, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4118,7 +4226,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and localAddr on same NIC as route", netCfg: fakeNetCfg, forwardingEnabled: false, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4126,7 +4234,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and localAddr on same NIC as route", netCfg: fakeNetCfg, forwardingEnabled: false, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4134,7 +4242,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and localAddr on different NIC as route", netCfg: fakeNetCfg, forwardingEnabled: false, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4142,7 +4250,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and localAddr on different NIC as route", netCfg: fakeNetCfg, forwardingEnabled: true, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: true, }, @@ -4166,7 +4274,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and link-local local addr with route on different NIC", netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: false, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4174,7 +4282,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and link-local local addr with route on same NIC", netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: true, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4182,7 +4290,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with route on same NIC", netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: true, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr, + localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4190,7 +4298,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and link-local local addr with route on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4198,7 +4306,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and link-local local addr with route on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4206,7 +4314,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with link-local remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4214,7 +4322,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and global local addr with link-local remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4222,7 +4330,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with link-local multicast remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4230,7 +4338,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and global local addr with link-local multicast remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4238,7 +4346,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with link-local multicast remote on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4246,7 +4354,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and global local addr with link-local multicast remote on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4268,12 +4376,20 @@ func TestFindRouteWithForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err) } - if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: test.netCfg.proto, + AddressWithPrefix: test.netCfg.nic1AddrWithPrefix, + } + if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err) } - if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: test.netCfg.proto, + AddressWithPrefix: test.netCfg.nic2AddrWithPrefix, + } + if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err) } if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil { @@ -4282,20 +4398,20 @@ func TestFindRouteWithForwarding(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) - r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) + r, err := s.FindRoute(test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) if err == nil { defer r.Release() } if diff := cmp.Diff(test.findRouteErr, err); diff != "" { - t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff) + t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, diff) } if test.findRouteErr != nil { return } - if r.LocalAddress() != test.localAddr { - t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddr) + if r.LocalAddress() != test.localAddrWithPrefix.Address { + t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddrWithPrefix.Address) } if r.RemoteAddress() != test.netCfg.remoteAddr { t.Errorf("got r.RemoteAddress() = %s, want = %s", r.RemoteAddress(), test.netCfg.remoteAddr) @@ -4318,8 +4434,8 @@ func TestFindRouteWithForwarding(t *testing.T) { if !ok { t.Fatal("packet not sent through ep2") } - if pkt.Route.LocalAddress != test.localAddr { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr) + if pkt.Route.LocalAddress != test.localAddrWithPrefix.Address { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddrWithPrefix.Address) } if pkt.Route.RemoteAddress != test.netCfg.remoteAddr { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr) diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go index dc7289441..a941091b0 100644 --- a/pkg/tcpip/stack/tcp.go +++ b/pkg/tcpip/stack/tcp.go @@ -289,6 +289,12 @@ type TCPSenderState struct { // RACKState holds the state related to RACK loss detection algorithm. RACKState TCPRACKState + + // RetransmitTS records the timestamp used to detect spurious recovery. + RetransmitTS uint32 + + // SpuriousRecovery indicates if the sender entered recovery spuriously. + SpuriousRecovery bool } // TCPSACKInfo holds TCP SACK related information for a given TCP endpoint. diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 824cf6526..3474c292a 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -32,11 +32,13 @@ type protocolIDs struct { // transportEndpoints manages all endpoints of a given protocol. It has its own // mutex so as to reduce interference between protocols. type transportEndpoints struct { - // mu protects all fields of the transportEndpoints. - mu sync.RWMutex + mu sync.RWMutex + // +checklocks:mu endpoints map[TransportEndpointID]*endpointsByNIC // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. + // + // +checklocks:mu rawEndpoints []RawTransportEndpoint } @@ -69,7 +71,7 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { // descending order of match quality. If a call to yield returns false, // iterEndpointsLocked stops iteration and returns immediately. // -// Preconditions: eps.mu must be locked. +// +checklocksread:eps.mu func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) { // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { @@ -110,7 +112,7 @@ func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield // findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in // descending order of match quality. // -// Preconditions: eps.mu must be locked. +// +checklocksread:eps.mu func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC { var matchedEPs []*endpointsByNIC eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { @@ -122,7 +124,7 @@ func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) [] // findEndpointLocked returns the endpoint that most closely matches the given id. // -// Preconditions: eps.mu must be locked. +// +checklocksread:eps.mu func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC { var matchedEP *endpointsByNIC eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { @@ -133,10 +135,12 @@ func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpo } type endpointsByNIC struct { - mu sync.RWMutex - endpoints map[tcpip.NICID]*multiPortEndpoint // seed is a random secret for a jenkins hash. seed uint32 + + mu sync.RWMutex + // +checklocks:mu + endpoints map[tcpip.NICID]*multiPortEndpoint } func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { @@ -171,7 +175,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet return true } // multiPortEndpoints are guaranteed to have at least one element. - transEP := selectEndpoint(id, mpep, epsByNIC.seed) + transEP := mpep.selectEndpoint(id, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { queuedProtocol.QueuePacket(transEP, id, pkt) epsByNIC.mu.RUnlock() @@ -200,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, tran // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNIC.seed).HandleError(transErr, pkt) + mpep.selectEndpoint(id, epsByNIC.seed).HandleError(transErr, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns @@ -333,15 +337,18 @@ func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber // // +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex `state:"nosave"` demux *transportDemuxer netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber + flags ports.FlagCounter + + mu sync.RWMutex `state:"nosave"` // endpoints stores the transport endpoints in the order in which they // were bound. This is required for UDP SO_REUSEADDR. + // + // +checklocks:mu endpoints []TransportEndpoint - flags ports.FlagCounter } func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { @@ -362,13 +369,16 @@ func reciprocalScale(val, n uint32) uint32 { // selectEndpoint calculates a hash of destination and source addresses and // ports then uses it to select a socket. In this case, all packets from one // address will be sent to same endpoint. -func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint { - if len(mpep.endpoints) == 1 { - return mpep.endpoints[0] +func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32) TransportEndpoint { + ep.mu.RLock() + defer ep.mu.RUnlock() + + if len(ep.endpoints) == 1 { + return ep.endpoints[0] } - if mpep.flags.SharedFlags().ToFlags().Effective().MostRecent { - return mpep.endpoints[len(mpep.endpoints)-1] + if ep.flags.SharedFlags().ToFlags().Effective().MostRecent { + return ep.endpoints[len(ep.endpoints)-1] } payload := []byte{ @@ -384,8 +394,8 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 h.Write([]byte(id.RemoteAddress)) hash := h.Sum32() - idx := reciprocalScale(hash, uint32(len(mpep.endpoints))) - return mpep.endpoints[idx] + idx := reciprocalScale(hash, uint32(len(ep.endpoints))) + return ep.endpoints[idx] } func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) { @@ -657,7 +667,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN } } - ep := selectEndpoint(id, mpep, epsByNIC.seed) + ep := mpep.selectEndpoint(id, epsByNIC.seed) epsByNIC.mu.RUnlock() return ep } diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 45b09110d..cd3a8c25a 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -35,7 +35,7 @@ import ( const ( testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + testDstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") testSrcAddrV4 = "\x0a\x00\x00\x01" testDstAddrV4 = "\x0a\x00\x00\x02" @@ -64,12 +64,20 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI } linkEps[linkEpID] = channelEp - if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil { - t.Fatalf("AddAddress IPv4 failed: %s", err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.Address(testDstAddrV4).WithPrefix(), + } + if err := s.AddProtocolAddress(linkEpID, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV4, err) } - if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil { - t.Fatalf("AddAddress IPv6 failed: %s", err) + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: testDstAddrV6.WithPrefix(), + } + if err := s.AddProtocolAddress(linkEpID, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV6, err) } } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 839178809..51870d03f 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -331,8 +331,11 @@ func (*fakeTransportProtocol) Wait() {} // Parse implements TransportProtocol.Parse. func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { - _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen) - return ok + if _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen); ok { + pkt.TransportProtocolNumber = fakeTransNumber + return true + } + return false } func fakeTransFactory(s *stack.Stack) stack.TransportProtocol { @@ -357,8 +360,15 @@ func TestTransportReceive(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Create endpoint and connect to remote address. @@ -428,8 +438,15 @@ func TestTransportControlReceive(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Create endpoint and connect to remote address. @@ -497,8 +514,15 @@ func TestTransportSend(t *testing.T) { t.Fatalf("CreateNIC failed: %v", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 55683b4fb..460a6afaf 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -19,7 +19,7 @@ // The starting point is the creation and configuration of a stack. A stack can // be created by calling the New() function of the tcpip/stack/stack package; // configuring a stack involves creating NICs (via calls to Stack.CreateNIC()), -// adding network addresses (via calls to Stack.AddAddress()), and +// adding network addresses (via calls to Stack.AddProtocolAddress()), and // setting a route table (via a call to Stack.SetRouteTable()). // // Once a stack is configured, endpoints can be created by calling @@ -423,9 +423,9 @@ type ControlMessages struct { // HasTimestamp indicates whether Timestamp is valid/set. HasTimestamp bool - // Timestamp is the time (in ns) that the last packet used to create - // the read data was received. - Timestamp int64 + // Timestamp is the time that the last packet used to create the read data + // was received. + Timestamp time.Time `state:".(int64)"` // HasInq indicates whether Inq is valid/set. HasInq bool @@ -451,6 +451,12 @@ type ControlMessages struct { // PacketInfo holds interface and address data on an incoming packet. PacketInfo IPPacketInfo + // HasIPv6PacketInfo indicates whether IPv6PacketInfo is set. + HasIPv6PacketInfo bool + + // IPv6PacketInfo holds interface and address data on an incoming packet. + IPv6PacketInfo IPv6PacketInfo + // HasOriginalDestinationAddress indicates whether OriginalDstAddress is // set. HasOriginalDstAddress bool @@ -465,10 +471,10 @@ type ControlMessages struct { // PacketOwner is used to get UID and GID of the packet. type PacketOwner interface { - // UID returns KUID of the packet. + // KUID returns KUID of the packet. KUID() uint32 - // GID returns KGID of the packet. + // KGID returns KGID of the packet. KGID() uint32 } @@ -1164,6 +1170,14 @@ type IPPacketInfo struct { DestinationAddr Address } +// IPv6PacketInfo is the message structure for IPV6_PKTINFO. +// +// +stateify savable +type IPv6PacketInfo struct { + Addr Address + NIC NICID +} + // SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to // get/set the default, min and max send buffer sizes. type SendBufferSizeOption struct { @@ -1231,11 +1245,11 @@ type Route struct { // String implements the fmt.Stringer interface. func (r Route) String() string { var out strings.Builder - fmt.Fprintf(&out, "%s", r.Destination) + _, _ = fmt.Fprintf(&out, "%s", r.Destination) if len(r.Gateway) > 0 { - fmt.Fprintf(&out, " via %s", r.Gateway) + _, _ = fmt.Fprintf(&out, " via %s", r.Gateway) } - fmt.Fprintf(&out, " nic %d", r.NIC) + _, _ = fmt.Fprintf(&out, " nic %d", r.NIC) return out.String() } @@ -1255,6 +1269,8 @@ type TransportProtocolNumber uint32 type NetworkProtocolNumber uint32 // A StatCounter keeps track of a statistic. +// +// +stateify savable type StatCounter struct { count atomicbitops.AlignedAtomicUint64 } @@ -1270,7 +1286,7 @@ func (s *StatCounter) Decrement() { } // Value returns the current value of the counter. -func (s *StatCounter) Value(name ...string) uint64 { +func (s *StatCounter) Value(...string) uint64 { return s.count.Load() } @@ -1849,6 +1865,10 @@ type TCPStats struct { // SegmentsAckedWithDSACK is the number of segments acknowledged with // DSACK. SegmentsAckedWithDSACK *StatCounter + + // SpuriousRecovery is the number of times the connection entered loss + // recovery spuriously. + SpuriousRecovery *StatCounter } // UDPStats collects UDP-specific stats. @@ -1981,6 +2001,8 @@ type Stats struct { } // ReceiveErrors collects packet receive errors within transport endpoint. +// +// +stateify savable type ReceiveErrors struct { // ReceiveBufferOverflow is the number of received packets dropped // due to the receive buffer being full. @@ -1998,8 +2020,10 @@ type ReceiveErrors struct { ChecksumErrors StatCounter } -// SendErrors collects packet send errors within the transport layer for -// an endpoint. +// SendErrors collects packet send errors within the transport layer for an +// endpoint. +// +// +stateify savable type SendErrors struct { // SendToNetworkFailed is the number of packets failed to be written to // the network endpoint. @@ -2010,6 +2034,8 @@ type SendErrors struct { } // ReadErrors collects segment read errors from an endpoint read call. +// +// +stateify savable type ReadErrors struct { // ReadClosed is the number of received packet drops because the endpoint // was shutdown for read. @@ -2025,6 +2051,8 @@ type ReadErrors struct { } // WriteErrors collects packet write errors from an endpoint write call. +// +// +stateify savable type WriteErrors struct { // WriteClosed is the number of packet drops because the endpoint // was shutdown for write. @@ -2040,6 +2068,8 @@ type WriteErrors struct { } // TransportEndpointStats collects statistics about the endpoint. +// +// +stateify savable type TransportEndpointStats struct { // PacketsReceived is the number of successful packet receives. PacketsReceived StatCounter diff --git a/pkg/tcpip/tcpip_state.go b/pkg/tcpip/tcpip_state.go new file mode 100644 index 000000000..1953e24a1 --- /dev/null +++ b/pkg/tcpip/tcpip_state.go @@ -0,0 +1,27 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "time" +) + +func (c *ControlMessages) saveTimestamp() int64 { + return c.Timestamp.UnixNano() +} + +func (c *ControlMessages) loadTimestamp(nsec int64) { + c.Timestamp = time.Unix(0, nsec) +} diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 181ef799e..99f4d4d0e 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -34,12 +34,16 @@ go_test( "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", "//pkg/tcpip/testutil", + "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", + "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) @@ -139,3 +143,25 @@ go_test( "@com_github_google_go_cmp//cmp:go_default_library", ], ) + +go_test( + name = "istio_test", + size = "small", + srcs = ["istio_test.go"], + deps = [ + "//pkg/context", + "//pkg/rand", + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/adapters/gonet", + "//pkg/tcpip/header", + "//pkg/tcpip/link/loopback", + "//pkg/tcpip/link/pipe", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", + "//pkg/tcpip/transport/tcp", + "@com_github_google_go_cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index 92fa6257d..6e1d4720d 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -473,11 +473,19 @@ func TestMulticastForwarding(t *testing.T) { t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err) } - if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: utils.Ipv4Addr, } - if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) + if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err) + } + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: utils.Ipv6Addr, + } + if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err) } if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { @@ -612,8 +620,8 @@ func TestPerInterfaceForwarding(t *testing.T) { addr: utils.RouterNIC2IPv6Addr, }, } { - if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err) + if err := s.AddProtocolAddress(add.nicID, add.addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", add.nicID, add.addr, err) } } diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index f9ab7d0af..957a779bf 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -15,19 +15,24 @@ package iptables_test import ( + "bytes" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" ) type inputIfNameMatcher struct { @@ -49,10 +54,10 @@ const ( nicName = "nic1" anotherNicName = "nic2" linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - srcAddrV4 = "\x0a\x00\x00\x01" - dstAddrV4 = "\x0a\x00\x00\x02" - srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + srcAddrV4 = tcpip.Address("\x0a\x00\x00\x01") + dstAddrV4 = tcpip.Address("\x0a\x00\x00\x02") + srcAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + dstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") payloadSize = 20 ) @@ -66,8 +71,12 @@ func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) { if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: dstAddrV6.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } return s, e } @@ -82,8 +91,12 @@ func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) { if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: dstAddrV4.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } return s, e } @@ -601,11 +614,19 @@ func TestIPTableWritePackets(t *testing.T) { if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err) + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: srcAddrV6.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err) + } + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: srcAddrV4.WithPrefix(), } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err) + if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err) } s.SetRouteTable([]tcpip.Route{ @@ -856,11 +877,19 @@ func TestForwardingHook(t *testing.T) { t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) } - if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: utils.Ipv4Addr.Address.WithPrefix(), } - if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) + if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err) + } + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: utils.Ipv6Addr.Address.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err) } if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { @@ -1037,22 +1066,22 @@ func TestInputHookWithLocalForwarding(t *testing.T) { if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) } - if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err) + if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv4Addr1, err) } - if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err) + if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv6Addr1, err) } e2 := channel.New(1, header.IPv6MinimumMTU, "") if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) } - if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err) + if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv4Addr2, err) } - if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err) + if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv6Addr2, err) } if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { @@ -1132,3 +1161,621 @@ func TestInputHookWithLocalForwarding(t *testing.T) { }) } } + +func TestNAT(t *testing.T) { + const listenPort uint16 = 8080 + + type endpointAndAddresses struct { + serverEP tcpip.Endpoint + serverAddr tcpip.FullAddress + serverReadableCH chan struct{} + serverConnectAddr tcpip.Address + + clientEP tcpip.Endpoint + clientAddr tcpip.Address + clientReadableCH chan struct{} + clientConnectAddr tcpip.FullAddress + } + + newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) { + t.Helper() + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.ReadableEvents) + t.Cleanup(func() { + wq.EventUnregister(&we) + }) + + ep, err := s.NewEndpoint(transProto, netProto, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err) + } + t.Cleanup(ep.Close) + + return ep, ch + } + + setupNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, hook stack.Hook, filter stack.IPHeaderFilter, target stack.Target) { + t.Helper() + + ipv6 := netProto == ipv6.ProtocolNumber + ipt := s.IPTables() + table := ipt.GetTable(stack.NATID, ipv6) + ruleIdx := table.BuiltinChains[hook] + table.Rules[ruleIdx].Filter = filter + table.Rules[ruleIdx].Target = target + // Make sure the packet is not dropped by the next rule. + table.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err) + } + } + + setupDNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, target stack.Target) { + t.Helper() + + setupNAT( + t, + s, + netProto, + stack.Prerouting, + stack.IPHeaderFilter{ + Protocol: transProto, + CheckProtocol: true, + InputInterface: utils.RouterNIC2Name, + }, + target) + } + + setupSNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, target stack.Target) { + t.Helper() + + setupNAT( + t, + s, + netProto, + stack.Postrouting, + stack.IPHeaderFilter{ + Protocol: transProto, + CheckProtocol: true, + OutputInterface: utils.RouterNIC1Name, + }, + target) + } + + type natType struct { + name string + setupNAT func(_ *testing.T, _ *stack.Stack, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address) + } + + snatTypes := []natType{ + { + name: "SNAT", + setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, _ tcpip.Address) { + t.Helper() + + setupSNAT(t, s, netProto, transProto, &stack.SNATTarget{NetworkProtocol: netProto, Addr: snatAddr}) + }, + }, + { + name: "Masquerade", + setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, _ tcpip.Address) { + t.Helper() + + setupSNAT(t, s, netProto, transProto, &stack.MasqueradeTarget{NetworkProtocol: netProto}) + }, + }, + } + dnatTypes := []natType{ + { + name: "Redirect", + setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, _ tcpip.Address) { + t.Helper() + + setupDNAT(t, s, netProto, transProto, &stack.RedirectTarget{NetworkProtocol: netProto, Port: listenPort}) + }, + }, + { + name: "DNAT", + setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, dnatAddr tcpip.Address) { + t.Helper() + + setupDNAT(t, s, netProto, transProto, &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: listenPort}) + }, + }, + } + + setupTwiceNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, dnatAddr tcpip.Address, snatTarget stack.Target) { + t.Helper() + + ipv6 := netProto == ipv6.ProtocolNumber + ipt := s.IPTables() + + table := stack.Table{ + Rules: []stack.Rule{ + // Prerouting + { + Filter: stack.IPHeaderFilter{ + Protocol: transProto, + CheckProtocol: true, + InputInterface: utils.RouterNIC2Name, + }, + Target: &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: listenPort}, + }, + { + Target: &stack.AcceptTarget{}, + }, + + // Input + { + Target: &stack.AcceptTarget{}, + }, + + // Forward + { + Target: &stack.AcceptTarget{}, + }, + + // Output + { + Target: &stack.AcceptTarget{}, + }, + + // Postrouting + { + Filter: stack.IPHeaderFilter{ + Protocol: transProto, + CheckProtocol: true, + OutputInterface: utils.RouterNIC1Name, + }, + Target: snatTarget, + }, + { + Target: &stack.AcceptTarget{}, + }, + }, + BuiltinChains: [stack.NumHooks]int{ + stack.Prerouting: 0, + stack.Input: 2, + stack.Forward: 3, + stack.Output: 4, + stack.Postrouting: 5, + }, + } + + if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err) + } + } + twiceNATTypes := []natType{ + { + name: "DNAT-Masquerade", + setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address) { + t.Helper() + + setupTwiceNAT(t, s, netProto, transProto, dnatAddr, &stack.MasqueradeTarget{NetworkProtocol: netProto}) + }, + }, + { + name: "DNAT-SNAT", + setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address) { + t.Helper() + + setupTwiceNAT(t, s, netProto, transProto, dnatAddr, &stack.SNATTarget{NetworkProtocol: netProto, Addr: snatAddr}) + }, + }, + } + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + // Setups up the stacks in such a way that: + // + // - Host2 is the client for all tests. + // - When performing SNAT only: + // + Host1 is the server. + // + NAT will transform client-originating packets' source addresses to + // the router's NIC1's address before reaching Host1. + // - When performing DNAT only: + // + Router is the server. + // + Client will send packets directed to Host1. + // + NAT will transform client-originating packets' destination addresses + // to the router's NIC2's address. + // - When performing Twice-NAT: + // + Host1 is the server. + // + Client will send packets directed to router's NIC2. + // + NAT will transform client originating packets' destination addresses + // to Host1's address. + // + NAT will transform client-originating packets' source addresses to + // the router's NIC1's address before reaching Host1. + epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses + natTypes []natType + }{ + { + name: "IPv4 SNAT", + netProto: ipv4.ProtocolNumber, + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + listenerStack := host1Stack + serverAddr := tcpip.FullAddress{ + Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, + Port: listenPort, + } + serverConnectAddr := utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address + clientConnectPort := serverAddr.Port + ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: serverAddr, + serverReadableCH: ep1WECH, + serverConnectAddr: serverConnectAddr, + + clientEP: ep2, + clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + clientConnectAddr: tcpip.FullAddress{ + Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, + Port: clientConnectPort, + }, + } + }, + natTypes: snatTypes, + }, + { + name: "IPv4 DNAT", + netProto: ipv4.ProtocolNumber, + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + // If we are performing DNAT, then the packet will be redirected + // to the router. + listenerStack := routerStack + serverAddr := tcpip.FullAddress{ + Addr: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, + Port: listenPort, + } + serverConnectAddr := utils.Host2IPv4Addr.AddressWithPrefix.Address + // DNAT will update the destination port to what the server is + // bound to. + clientConnectPort := serverAddr.Port + 1 + ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: serverAddr, + serverReadableCH: ep1WECH, + serverConnectAddr: serverConnectAddr, + + clientEP: ep2, + clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + clientConnectAddr: tcpip.FullAddress{ + Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, + Port: clientConnectPort, + }, + } + }, + natTypes: dnatTypes, + }, + { + name: "IPv4 Twice-NAT", + netProto: ipv4.ProtocolNumber, + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + listenerStack := host1Stack + serverAddr := tcpip.FullAddress{ + Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, + Port: listenPort, + } + serverConnectAddr := utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address + clientConnectPort := serverAddr.Port + ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: serverAddr, + serverReadableCH: ep1WECH, + serverConnectAddr: serverConnectAddr, + + clientEP: ep2, + clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + clientConnectAddr: tcpip.FullAddress{ + Addr: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, + Port: clientConnectPort, + }, + } + }, + natTypes: twiceNATTypes, + }, + { + name: "IPv6 SNAT", + netProto: ipv6.ProtocolNumber, + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + listenerStack := host1Stack + serverAddr := tcpip.FullAddress{ + Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address, + Port: listenPort, + } + serverConnectAddr := utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address + clientConnectPort := serverAddr.Port + ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: serverAddr, + serverReadableCH: ep1WECH, + serverConnectAddr: serverConnectAddr, + + clientEP: ep2, + clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + clientConnectAddr: tcpip.FullAddress{ + Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address, + Port: clientConnectPort, + }, + } + }, + natTypes: snatTypes, + }, + { + name: "IPv6 DNAT", + netProto: ipv6.ProtocolNumber, + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + // If we are performing DNAT, then the packet will be redirected + // to the router. + listenerStack := routerStack + serverAddr := tcpip.FullAddress{ + Addr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, + Port: listenPort, + } + serverConnectAddr := utils.Host2IPv6Addr.AddressWithPrefix.Address + // DNAT will update the destination port to what the server is + // bound to. + clientConnectPort := serverAddr.Port + 1 + ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: serverAddr, + serverReadableCH: ep1WECH, + serverConnectAddr: serverConnectAddr, + + clientEP: ep2, + clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + clientConnectAddr: tcpip.FullAddress{ + Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address, + Port: clientConnectPort, + }, + } + }, + natTypes: dnatTypes, + }, + { + name: "IPv6 Twice-NAT", + netProto: ipv6.ProtocolNumber, + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + listenerStack := host1Stack + serverAddr := tcpip.FullAddress{ + Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address, + Port: listenPort, + } + serverConnectAddr := utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address + clientConnectPort := serverAddr.Port + ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: serverAddr, + serverReadableCH: ep1WECH, + serverConnectAddr: serverConnectAddr, + + clientEP: ep2, + clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + clientConnectAddr: tcpip.FullAddress{ + Addr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, + Port: clientConnectPort, + }, + } + }, + natTypes: twiceNATTypes, + }, + } + + subTests := []struct { + name string + proto tcpip.TransportProtocolNumber + expectedConnectErr tcpip.Error + setupServer func(t *testing.T, ep tcpip.Endpoint) + setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) + needRemoteAddr bool + }{ + { + name: "UDP", + proto: udp.ProtocolNumber, + expectedConnectErr: nil, + setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { + t.Helper() + + if err := ep.Connect(clientAddr); err != nil { + t.Fatalf("ep.Connect(%#v): %s", clientAddr, err) + } + return nil, nil + }, + needRemoteAddr: true, + }, + { + name: "TCP", + proto: tcp.ProtocolNumber, + expectedConnectErr: &tcpip.ErrConnectStarted{}, + setupServer: func(t *testing.T, ep tcpip.Endpoint) { + t.Helper() + + if err := ep.Listen(1); err != nil { + t.Fatalf("ep.Listen(1): %s", err) + } + }, + setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { + t.Helper() + + var addr tcpip.FullAddress + for { + newEP, wq, err := ep.Accept(&addr) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + <-ch + continue + } + if err != nil { + t.Fatalf("ep.Accept(_): %s", err) + } + if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath( + "NIC", + )); diff != "" { + t.Errorf("accepted address mismatch (-want +got):\n%s", diff) + } + + we, newCH := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.ReadableEvents) + return newEP, newCH + } + }, + needRemoteAddr: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + for _, natType := range test.natTypes { + t.Run(natType.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + } + + host1Stack := stack.New(stackOpts) + routerStack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack) + + epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) + natType.setupNAT(t, routerStack, test.netProto, subTest.proto, epsAndAddrs.serverConnectAddr, epsAndAddrs.serverAddr.Addr) + + if err := epsAndAddrs.serverEP.Bind(epsAndAddrs.serverAddr); err != nil { + t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", epsAndAddrs.serverAddr, err) + } + clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} + if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { + t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) + } + + if subTest.setupServer != nil { + subTest.setupServer(t, epsAndAddrs.serverEP) + } + { + err := epsAndAddrs.clientEP.Connect(epsAndAddrs.clientConnectAddr) + if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { + t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", epsAndAddrs.clientConnectAddr, diff) + } + } + serverConnectAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverConnectAddr} + if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { + t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) + } else { + serverConnectAddr.Port = addr.Port + } + + serverEP := epsAndAddrs.serverEP + serverCH := epsAndAddrs.serverReadableCH + if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, serverConnectAddr); ep != nil { + defer ep.Close() + serverEP = ep + serverCH = ch + } + + write := func(ep tcpip.Endpoint, data []byte) { + t.Helper() + + var r bytes.Reader + r.Reset(data) + var wOpts tcpip.WriteOptions + n, err := ep.Write(&r, wOpts) + if err != nil { + t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) + } + if want := int64(len(data)); n != want { + t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want) + } + } + + read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) { + t.Helper() + + var buf bytes.Buffer + var res tcpip.ReadResult + for { + var err tcpip.Error + opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} + res, err = ep.Read(&buf, opts) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + <-ch + continue + } + if err != nil { + t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) + } + break + } + + readResult := tcpip.ReadResult{ + Count: len(data), + Total: len(data), + } + if subTest.needRemoteAddr { + readResult.RemoteAddr = expectedFrom + } + if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + )); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) + } + if diff := cmp.Diff(buf.Bytes(), data); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + + if t.Failed() { + t.FailNow() + } + } + + { + data := []byte{1, 2, 3, 4} + write(epsAndAddrs.clientEP, data) + read(serverCH, serverEP, data, serverConnectAddr) + } + + { + data := []byte{5, 6, 7, 8, 9, 10, 11, 12} + write(serverEP, data) + read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, epsAndAddrs.clientConnectAddr) + } + }) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/istio_test.go b/pkg/tcpip/tests/integration/istio_test.go new file mode 100644 index 000000000..95d994ef8 --- /dev/null +++ b/pkg/tcpip/tests/integration/istio_test.go @@ -0,0 +1,365 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package istio_test + +import ( + "fmt" + "io" + "net" + "net/http" + "strconv" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/link/pipe" + "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" +) + +// testContext encapsulates the state required to run tests that simulate +// an istio like environment. +// +// A diagram depicting the setup is shown below. +// +-----------------------------------------------------------------------+ +// | +-------------------------------------------------+ | +// | + ----------+ | + -----------------+ PROXY +----------+ | | +// | | clientEP | | | serverListeningEP|--accepted-> | serverEP |-+ | | +// | + ----------+ | + -----------------+ +----------+ | | | +// | | -------|-------------+ +----------+ | | | +// | | | | | proxyEP |-+ | | +// | +-----redirect | +----------+ | | +// | + ------------+---|------+---+ | +// | | | +// | Local Stack. | | +// +-------------------------------------------------------|---------------+ +// | +// +-----------------------------------------------------------------------+ +// | remoteStack | | +// | +-------------SYN ---------------| | +// | | | | +// | +-------------------|--------------------------------|-_---+ | +// | | + -----------------+ + ----------+ | | | +// | | | remoteListeningEP|--accepted--->| remoteEP |<++ | | +// | | + -----------------+ + ----------+ | | +// | | Remote HTTP Server | | +// | +----------------------------------------------------------+ | +// +-----------------------------------------------------------------------+ +// +type testContext struct { + // localServerListener is the listening port for the server which will proxy + // all traffic to the remote EP. + localServerListener *gonet.TCPListener + + // remoteListenListener is the remote listening endpoint that will receive + // connections from server. + remoteServerListener *gonet.TCPListener + + // localStack is the stack used to create client/server endpoints and + // also the stack on which we install NAT redirect rules. + localStack *stack.Stack + + // remoteStack is the stack that represents a *remote* server. + remoteStack *stack.Stack + + // defaultResponse is the response served by the HTTP server for all GET + defaultResponse []byte + + // requests. wg is used to wait for HTTP server and Proxy to terminate before + // returning from cleanup. + wg sync.WaitGroup +} + +func (ctx *testContext) cleanup() { + ctx.localServerListener.Close() + ctx.localStack.Close() + ctx.remoteServerListener.Close() + ctx.remoteStack.Close() + ctx.wg.Wait() +} + +const ( + localServerPort = 8080 + remoteServerPort = 9090 +) + +var ( + localIPv4Addr1 = testutil.MustParse4("10.0.0.1") + localIPv4Addr2 = testutil.MustParse4("10.0.0.2") + loopbackIPv4Addr = testutil.MustParse4("127.0.0.1") + remoteIPv4Addr1 = testutil.MustParse4("10.0.0.3") +) + +func newTestContext(t *testing.T) *testContext { + t.Helper() + localNIC, remoteNIC := pipe.New("" /* linkAddr1 */, "" /* linkAddr2 */) + + localStack := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + HandleLocal: true, + }) + + remoteStack := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + HandleLocal: true, + }) + + // Add loopback NIC. We need a loopback NIC as NAT redirect rule redirect to + // loopback address + specified port. + loopbackNIC := loopback.New() + const loopbackNICID = tcpip.NICID(1) + if err := localStack.CreateNIC(loopbackNICID, sniffer.New(loopbackNIC)); err != nil { + t.Fatalf("localStack.CreateNIC(%d, _): %s", loopbackNICID, err) + } + loopbackAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: loopbackIPv4Addr.WithPrefix(), + } + if err := localStack.AddProtocolAddress(loopbackNICID, loopbackAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("localStack.AddProtocolAddress(%d, %+v, {}): %s", loopbackNICID, loopbackAddr, err) + } + + // Create linked NICs that connects the local and remote stack. + const localNICID = tcpip.NICID(2) + const remoteNICID = tcpip.NICID(3) + if err := localStack.CreateNIC(localNICID, sniffer.New(localNIC)); err != nil { + t.Fatalf("localStack.CreateNIC(%d, _): %s", localNICID, err) + } + if err := remoteStack.CreateNIC(remoteNICID, sniffer.New(remoteNIC)); err != nil { + t.Fatalf("remoteStack.CreateNIC(%d, _): %s", remoteNICID, err) + } + + for _, addr := range []tcpip.Address{localIPv4Addr1, localIPv4Addr2} { + localProtocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: addr.WithPrefix(), + } + if err := localStack.AddProtocolAddress(localNICID, localProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("localStack.AddProtocolAddress(%d, %+v, {}): %s", localNICID, localProtocolAddr, err) + } + } + + remoteProtocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: remoteIPv4Addr1.WithPrefix(), + } + if err := remoteStack.AddProtocolAddress(remoteNICID, remoteProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("remoteStack.AddProtocolAddress(%d, %+v, {}): %s", remoteNICID, remoteProtocolAddr, err) + } + + // Setup route table for local and remote stacks. + localStack.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4LoopbackSubnet, + NIC: loopbackNICID, + }, + { + Destination: header.IPv4EmptySubnet, + NIC: localNICID, + }, + }) + remoteStack.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: remoteNICID, + }, + }) + + const netProto = ipv4.ProtocolNumber + localServerAddress := tcpip.FullAddress{ + Port: localServerPort, + } + + localServerListener, err := gonet.ListenTCP(localStack, localServerAddress, netProto) + if err != nil { + t.Fatalf("gonet.ListenTCP(_, %+v, %d) = %s", localServerAddress, netProto, err) + } + + remoteServerAddress := tcpip.FullAddress{ + Port: remoteServerPort, + } + remoteServerListener, err := gonet.ListenTCP(remoteStack, remoteServerAddress, netProto) + if err != nil { + t.Fatalf("gonet.ListenTCP(_, %+v, %d) = %s", remoteServerAddress, netProto, err) + } + + // Initialize a random default response served by the HTTP server. + defaultResponse := make([]byte, 512<<10) + if _, err := rand.Read(defaultResponse); err != nil { + t.Fatalf("rand.Read(buf) failed: %s", err) + } + + tc := &testContext{ + localServerListener: localServerListener, + remoteServerListener: remoteServerListener, + localStack: localStack, + remoteStack: remoteStack, + defaultResponse: defaultResponse, + } + + tc.startServers(t) + return tc +} + +func (ctx *testContext) startServers(t *testing.T) { + ctx.wg.Add(1) + go func() { + defer ctx.wg.Done() + ctx.startHTTPServer() + }() + ctx.wg.Add(1) + go func() { + defer ctx.wg.Done() + ctx.startTCPProxyServer(t) + }() +} + +func (ctx *testContext) startTCPProxyServer(t *testing.T) { + t.Helper() + for { + conn, err := ctx.localServerListener.Accept() + if err != nil { + t.Logf("terminating local proxy server: %s", err) + return + } + // Start a goroutine to handle this inbound connection. + go func() { + remoteServerAddr := tcpip.FullAddress{ + Addr: remoteIPv4Addr1, + Port: remoteServerPort, + } + localServerAddr := tcpip.FullAddress{ + Addr: localIPv4Addr2, + } + serverConn, err := gonet.DialTCPWithBind(context.Background(), ctx.localStack, localServerAddr, remoteServerAddr, ipv4.ProtocolNumber) + if err != nil { + t.Logf("gonet.DialTCP(_, %+v, %d) = %s", remoteServerAddr, ipv4.ProtocolNumber, err) + return + } + proxy(conn, serverConn) + t.Logf("proxying completed") + }() + } +} + +// proxy transparently proxies the TCP payload from conn1 to conn2 +// and vice versa. +func proxy(conn1, conn2 net.Conn) { + var wg sync.WaitGroup + wg.Add(1) + go func() { + io.Copy(conn2, conn1) + conn1.Close() + conn2.Close() + }() + wg.Add(1) + go func() { + io.Copy(conn1, conn2) + conn1.Close() + conn2.Close() + }() + wg.Wait() +} + +func (ctx *testContext) startHTTPServer() { + handlerFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(ctx.defaultResponse)) + }) + s := &http.Server{ + Handler: handlerFunc, + } + s.Serve(ctx.remoteServerListener) +} + +func TestOutboundNATRedirect(t *testing.T) { + ctx := newTestContext(t) + defer ctx.cleanup() + + // Install an IPTable rule to redirect all TCP traffic with the sourceIP of + // localIPv4Addr1 to the tcp proxy port. + ipt := ctx.localStack.IPTables() + tbl := ipt.GetTable(stack.NATID, false /* ipv6 */) + ruleIdx := tbl.BuiltinChains[stack.Output] + tbl.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ + Protocol: tcp.ProtocolNumber, + CheckProtocol: true, + Src: localIPv4Addr1, + SrcMask: tcpip.Address("\xff\xff\xff\xff"), + } + tbl.Rules[ruleIdx].Target = &stack.RedirectTarget{ + Port: localServerPort, + NetworkProtocol: ipv4.ProtocolNumber, + } + tbl.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, tbl, false /* ipv6 */); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, false): %s", stack.NATID, err) + } + + dialFunc := func(protocol, address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, fmt.Errorf("unable to parse address: %s, err: %s", address, err) + } + + remoteServerIP := net.ParseIP(host) + remoteServerPort, err := strconv.Atoi(port) + if err != nil { + return nil, fmt.Errorf("unable to parse port from string %s, err: %s", port, err) + } + remoteAddress := tcpip.FullAddress{ + Addr: tcpip.Address(remoteServerIP.To4()), + Port: uint16(remoteServerPort), + } + + // Dial with an explicit source address bound so that the redirect rule will + // be able to correctly redirect these packets. + localAddr := tcpip.FullAddress{Addr: localIPv4Addr1} + return gonet.DialTCPWithBind(context.Background(), ctx.localStack, localAddr, remoteAddress, ipv4.ProtocolNumber) + } + + httpClient := &http.Client{ + Transport: &http.Transport{ + Dial: dialFunc, + }, + } + + serverURL := fmt.Sprintf("http://[%s]:%d/", net.IP(remoteIPv4Addr1), remoteServerPort) + response, err := httpClient.Get(serverURL) + if err != nil { + t.Fatalf("httpClient.Get(\"/\") failed: %s", err) + } + if got, want := response.StatusCode, http.StatusOK; got != want { + t.Fatalf("unexpected status code got: %d, want: %d", got, want) + } + body, err := io.ReadAll(response.Body) + if err != nil { + t.Fatalf("io.ReadAll(response.Body) failed: %s", err) + } + response.Body.Close() + if diff := cmp.Diff(body, ctx.defaultResponse); diff != "" { + t.Fatalf("unexpected response (-want +got): \n %s", diff) + } +} diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 27caa0c28..95ddd8ec3 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -56,17 +56,17 @@ func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tc t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) } - if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv4Addr1, err) + if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv4Addr1, err) } - if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv4Addr2, err) + if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv4Addr2, err) } - if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv6Addr1, err) + if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv6Addr1, err) } - if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv6Addr2, err) + if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv6Addr2, err) } host1Stack.SetRouteTable([]tcpip.Route{ @@ -568,8 +568,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) { Protocol: test.networkProtocolNumber, AddressWithPrefix: test.incomingAddr, } - if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err) + if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingProtoAddr, err) } // Set up endpoint through which we will attempt to forward packets. @@ -582,8 +582,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) { Protocol: test.networkProtocolNumber, AddressWithPrefix: test.outgoingAddr, } - if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err) + if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index b2008f0b2..f33223e79 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -195,8 +195,8 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, test.addAddress, err) + if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err) } s.SetRouteTable([]tcpip.Route{ { @@ -290,8 +290,8 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) } s.SetRouteTable([]tcpip.Route{ { @@ -431,8 +431,8 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, test.addAddress, err) + if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err) } s.SetRouteTable([]tcpip.Route{ { @@ -693,21 +693,40 @@ func TestExternalLoopbackTraffic(t *testing.T) { if err := s.CreateNIC(nicID1, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddAddressWithPrefix(nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr, err) + v4Addr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: utils.Ipv4Addr, } - if err := s.AddAddressWithPrefix(nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr, err) + if err := s.AddProtocolAddress(nicID1, v4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v4Addr, err) + } + v6Addr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: utils.Ipv6Addr, + } + if err := s.AddProtocolAddress(nicID1, v6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v6Addr, err) } if err := s.CreateNIC(nicID2, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) } - if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, ipv4Loopback); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, ipv4Loopback, err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: ipv4Loopback, + PrefixLen: 8, + }, + } + if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err) + } + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: header.IPv6Loopback.WithPrefix(), } - if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, header.IPv6Loopback); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, header.IPv6Loopback, err) + if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err) } if test.forwarding { diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 2d0a6e6a7..7753e7d6e 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -119,12 +119,12 @@ func TestPingMulticastBroadcast(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr} - if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err) } ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: utils.Ipv6Addr} - if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err) + if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtoAddr, err) } // Default routes for IPv4 and IPv6 so ICMP can find a route to the remote @@ -396,8 +396,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) } var wq waiter.Queue @@ -474,8 +474,8 @@ func TestReuseAddrAndBroadcast(t *testing.T) { PrefixLen: 8, }, } - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -642,8 +642,8 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) } // Set the route table so that UDP can find a NIC that is diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index ac3c703d4..422eb8408 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -47,7 +47,10 @@ func TestLocalPing(t *testing.T) { // request/reply packets. icmpDataOffset = 8 ) - ipv4Loopback := testutil.MustParse4("127.0.0.1") + ipv4Loopback := tcpip.AddressWithPrefix{ + Address: testutil.MustParse4("127.0.0.1"), + PrefixLen: 8, + } channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") } channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) { @@ -82,7 +85,7 @@ func TestLocalPing(t *testing.T) { transProto tcpip.TransportProtocolNumber netProto tcpip.NetworkProtocolNumber linkEndpoint func() stack.LinkEndpoint - localAddr tcpip.Address + localAddr tcpip.AddressWithPrefix icmpBuf func(*testing.T) buffer.View expectedConnectErr tcpip.Error checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint) @@ -101,7 +104,7 @@ func TestLocalPing(t *testing.T) { transProto: icmp.ProtocolNumber6, netProto: ipv6.ProtocolNumber, linkEndpoint: loopback.New, - localAddr: header.IPv6Loopback, + localAddr: header.IPv6Loopback.WithPrefix(), icmpBuf: ipv6ICMPBuf, checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, }, @@ -110,7 +113,7 @@ func TestLocalPing(t *testing.T) { transProto: icmp.ProtocolNumber4, netProto: ipv4.ProtocolNumber, linkEndpoint: channelEP, - localAddr: utils.Ipv4Addr.Address, + localAddr: utils.Ipv4Addr, icmpBuf: ipv4ICMPBuf, checkLinkEndpoint: channelEPCheck, }, @@ -119,7 +122,7 @@ func TestLocalPing(t *testing.T) { transProto: icmp.ProtocolNumber6, netProto: ipv6.ProtocolNumber, linkEndpoint: channelEP, - localAddr: utils.Ipv6Addr.Address, + localAddr: utils.Ipv6Addr, icmpBuf: ipv6ICMPBuf, checkLinkEndpoint: channelEPCheck, }, @@ -182,9 +185,13 @@ func TestLocalPing(t *testing.T) { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if len(test.localAddr) != 0 { - if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) + if len(test.localAddr.Address) != 0 { + protocolAddr := tcpip.ProtocolAddress{ + Protocol: test.netProto, + AddressWithPrefix: test.localAddr, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } @@ -197,7 +204,7 @@ func TestLocalPing(t *testing.T) { } defer ep.Close() - connAddr := tcpip.FullAddress{Addr: test.localAddr} + connAddr := tcpip.FullAddress{Addr: test.localAddr.Address} if err := ep.Connect(connAddr); err != test.expectedConnectErr { t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) } @@ -229,8 +236,8 @@ func TestLocalPing(t *testing.T) { if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" { t.Errorf("received data mismatch (-want +got):\n%s", diff) } - if rr.RemoteAddr.Addr != test.localAddr { - t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr) + if rr.RemoteAddr.Addr != test.localAddr.Address { + t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr.Address) } test.checkLinkEndpoint(t, e) @@ -302,11 +309,12 @@ func TestLocalUDP(t *testing.T) { } if subTest.addAddress { - if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil { - t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err) + if err := s.AddProtocolAddress(nicID, test.canBePrimaryAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, test.canBePrimaryAddr, err) } - if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err) + properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} + if err := s.AddProtocolAddress(nicID, test.firstPrimaryAddr, properties); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, %+v): %s", nicID, test.firstPrimaryAddr, properties, err) } } diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index 2e6ae55ea..c69410859 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -40,6 +40,14 @@ const ( Host2NICID = 4 ) +// Common NIC names used by tests. +const ( + Host1NICName = "host1NIC" + RouterNIC1Name = "routerNIC1" + RouterNIC2Name = "routerNIC2" + Host2NICName = "host2NIC" +) + // Common link addresses used by tests. const ( LinkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") @@ -211,17 +219,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. host1NIC, routerNIC1 := pipe.New(LinkAddr1, LinkAddr2) routerNIC2, host2NIC := pipe.New(LinkAddr3, LinkAddr4) - if err := host1Stack.CreateNIC(Host1NICID, NewEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", Host1NICID, err) + { + opts := stack.NICOptions{Name: Host1NICName} + if err := host1Stack.CreateNICWithOptions(Host1NICID, NewEthernetEndpoint(host1NIC), opts); err != nil { + t.Fatalf("host1Stack.CreateNICWithOptions(%d, _, %#v): %s", Host1NICID, opts, err) + } } - if err := routerStack.CreateNIC(RouterNICID1, NewEthernetEndpoint(routerNIC1)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID1, err) + { + opts := stack.NICOptions{Name: RouterNIC1Name} + if err := routerStack.CreateNICWithOptions(RouterNICID1, NewEthernetEndpoint(routerNIC1), opts); err != nil { + t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID1, opts, err) + } } - if err := routerStack.CreateNIC(RouterNICID2, NewEthernetEndpoint(routerNIC2)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID2, err) + { + opts := stack.NICOptions{Name: RouterNIC2Name} + if err := routerStack.CreateNICWithOptions(RouterNICID2, NewEthernetEndpoint(routerNIC2), opts); err != nil { + t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID2, opts, err) + } } - if err := host2Stack.CreateNIC(Host2NICID, NewEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err) + { + opts := stack.NICOptions{Name: Host2NICName} + if err := host2Stack.CreateNICWithOptions(Host2NICID, NewEthernetEndpoint(host2NIC), opts); err != nil { + t.Fatalf("host2Stack.CreateNICWithOptions(%d, _, %#v): %s", Host2NICID, opts, err) + } } if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { @@ -231,29 +251,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err) } - if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv4Addr, err) + if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv4Addr, err) } - if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv4Addr, err) + if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv4Addr, err) } - if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv4Addr, err) + if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv4Addr, err) } - if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv4Addr, err) + if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv4Addr, err) } - if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv6Addr, err) + if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv6Addr, err) } - if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv6Addr, err) + if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv6Addr, err) } - if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv6Addr, err) + if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv6Addr, err) } - if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv6Addr, err) + if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv6Addr, err) } host1Stack.SetRouteTable([]tcpip.Route{ diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index bbc0e3ecc..4718ec4ec 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -33,6 +33,8 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/ports", "//pkg/tcpip/stack", + "//pkg/tcpip/transport", + "//pkg/tcpip/transport/internal/network", "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", "//pkg/waiter", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 00497bf07..995f58616 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,6 +15,7 @@ package icmp import ( + "fmt" "io" "time" @@ -24,6 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" + "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network" "gvisor.dev/gvisor/pkg/waiter" ) @@ -35,15 +38,6 @@ type icmpPacket struct { receivedAt time.Time `state:".(int64)"` } -type endpointState int - -const ( - stateInitial endpointState = iota - stateBound - stateConnected - stateClosed -) - // endpoint represents an ICMP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -51,14 +45,17 @@ const ( // // +stateify savable type endpoint struct { - stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` + transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue uniqueID uint64 + net network.Endpoint + stats tcpip.TransportEndpointStats + ops tcpip.SocketOptions // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -70,38 +67,23 @@ type endpoint struct { // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` - // shutdownFlags represent the current shutdown state of the endpoint. - shutdownFlags tcpip.ShutdownFlags - state endpointState - route *stack.Route `state:"manual"` - ttl uint8 - stats tcpip.TransportEndpointStats `state:"nosave"` - - // owner is used to get uid and gid of the packet. - owner tcpip.PacketOwner - - // ops is used to get socket level options. - ops tcpip.SocketOptions - // frozen indicates if the packets should be delivered to the endpoint // during restore. frozen bool + ident uint16 } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ - stack: s, - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - TransProto: transProto, - }, + stack: s, + transProto: transProto, waiterQueue: waiterQueue, - state: stateInitial, uniqueID: s.UniqueID(), } ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) ep.ops.SetSendBufferSize(32*1024, false /* notify */) ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) + ep.net.Init(s, netProto, transProto, &ep.ops) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -128,35 +110,40 @@ func (e *endpoint) Abort() { // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { - e.mu.Lock() - e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite - switch e.state { - case stateBound, stateConnected: - bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) - e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice) - } - - // Close the receive list and drain it. - e.rcvMu.Lock() - e.rcvClosed = true - e.rcvBufSize = 0 - for !e.rcvList.Empty() { - p := e.rcvList.Front() - e.rcvList.Remove(p) - } - e.rcvMu.Unlock() + notify := func() bool { + e.mu.Lock() + defer e.mu.Unlock() + + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateClosed: + return false + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + info := e.net.Info() + info.ID.LocalPort = e.ident + e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{info.NetProto}, e.transProto, info.ID, e, ports.Flags{}, tcpip.NICID(e.ops.GetBindToDevice())) + default: + panic(fmt.Sprintf("unhandled state = %s", state)) + } - if e.route != nil { - e.route.Release() - e.route = nil - } + e.net.Shutdown() + e.net.Close() - // Update the state. - e.state = stateClosed + e.rcvMu.Lock() + defer e.rcvMu.Unlock() + e.rcvClosed = true + e.rcvBufSize = 0 + for !e.rcvList.Empty() { + p := e.rcvList.Front() + e.rcvList.Remove(p) + } - e.mu.Unlock() + return true + }() - e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) + if notify { + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) + } } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. @@ -164,7 +151,7 @@ func (*endpoint) ModerateRecvBuf(int) {} // SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { - e.owner = owner + e.net.SetOwner(owner) } // Read implements tcpip.Endpoint.Read. @@ -193,7 +180,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult Total: p.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: p.receivedAt.UnixNano(), + Timestamp: p.receivedAt, }, } if opts.NeedRemoteAddr { @@ -213,14 +200,13 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. -// +checklocks:e.mu -func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { - switch e.state { - case stateInitial: - case stateConnected: +// +checklocksread:e.mu +func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { + switch e.net.State() { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateConnected: return false, nil - - case stateBound: + case transport.DatagramEndpointStateBound: if to == nil { return false, &tcpip.ErrDestinationRequired{} } @@ -235,7 +221,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. - if e.state != stateInitial { + if e.net.State() != transport.DatagramEndpointStateInitial { return true, nil } @@ -270,27 +256,15 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) - if opts.More { - return 0, &tcpip.ErrInvalidOptionValue{} - } - - to := opts.To - +func (e *endpoint) prepareForWrite(opts tcpip.WriteOptions) (network.WriteContext, uint16, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - // If we've shutdown with SHUT_WR we are in an invalid state for sending. - if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, &tcpip.ErrClosedForSend{} - } - // Prepare for write. for { - retry, err := e.prepareForWrite(to) + retry, err := e.prepareForWriteInner(opts.To) if err != nil { - return 0, err + return network.WriteContext{}, 0, err } if !retry { @@ -298,36 +272,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } } - route := e.route - if to != nil { - // Reject destination address if it goes through a different - // NIC than the endpoint was bound to. - nicID := to.NIC - if nicID == 0 { - nicID = tcpip.NICID(e.ops.GetBindToDevice()) - } - if e.BindNICID != 0 { - if nicID != 0 && nicID != e.BindNICID { - return 0, &tcpip.ErrNoRoute{} - } - - nicID = e.BindNICID - } - - dst, netProto, err := e.checkV4MappedLocked(*to) - if err != nil { - return 0, err - } - - // Find the endpoint. - r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */) - if err != nil { - return 0, err - } - defer r.Release() + ctx, err := e.net.AcquireContextForWrite(opts) + return ctx, e.ident, err +} - route = r +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { + ctx, ident, err := e.prepareForWrite(opts) + if err != nil { + return 0, err } + defer ctx.Release() // TODO(https://gvisor.dev/issue/6538): Avoid this allocation. v := make([]byte, p.Len()) @@ -335,17 +289,18 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return 0, &tcpip.ErrBadBuffer{} } - var err tcpip.Error - switch e.NetProto { + switch netProto, pktInfo := e.net.NetProto(), ctx.PacketInfo(); netProto { case header.IPv4ProtocolNumber: - err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) + if err := send4(e.stack, &ctx, ident, v, pktInfo.MaxHeaderLength); err != nil { + return 0, err + } case header.IPv6ProtocolNumber: - err = send6(route, e.ID.LocalPort, v, e.ttl) - } - - if err != nil { - return 0, err + if err := send6(e.stack, &ctx, ident, v, pktInfo.LocalAddress, pktInfo.RemoteAddress, pktInfo.MaxHeaderLength); err != nil { + return 0, err + } + default: + panic(fmt.Sprintf("unhandled network protocol = %d", netProto)) } return int64(len(v)), nil @@ -358,24 +313,17 @@ func (e *endpoint) HasNIC(id int32) bool { return e.stack.HasNIC(tcpip.NICID(id)) } -// SetSockOpt sets a socket option. -func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { - return nil +// SetSockOpt implements tcpip.Endpoint. +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { + return e.net.SetSockOpt(opt) } -// SetSockOptInt sets a socket option. Currently not supported. +// SetSockOptInt implements tcpip.Endpoint. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.TTLOption: - e.mu.Lock() - e.ttl = uint8(v) - e.mu.Unlock() - - } - return nil + return e.net.SetSockOptInt(opt, v) } -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +// GetSockOptInt implements tcpip.Endpoint. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -388,31 +336,24 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.TTLOption: - e.rcvMu.Lock() - v := int(e.ttl) - e.rcvMu.Unlock() - return v, nil - default: - return -1, &tcpip.ErrUnknownProtocolOption{} + return e.net.GetSockOptInt(opt) } } -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { - return &tcpip.ErrUnknownProtocolOption{} +// GetSockOpt implements tcpip.Endpoint. +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return e.net.GetSockOpt(opt) } -func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) tcpip.Error { +func send4(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, maxHeaderLength uint16) tcpip.Error { if len(data) < header.ICMPv4MinimumSize { return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()), + ReserveHeaderBytes: header.ICMPv4MinimumSize + int(maxHeaderLength), }) - pkt.Owner = owner icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize)) pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber @@ -427,36 +368,31 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi return &tcpip.ErrInvalidEndpointState{} } - // Because this icmp endpoint is implemented in the transport layer, we can - // only increment the 'stack-wide' stats but we can't increment the - // 'per-NetworkEndpoint' stats. - sentStat := r.Stats().ICMP.V4.PacketsSent.EchoRequest - icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) - pkt.Data().AppendView(data) - if ttl == 0 { - ttl = r.DefaultTTL() - } + // Because this icmp endpoint is implemented in the transport layer, we can + // only increment the 'stack-wide' stats but we can't increment the + // 'per-NetworkEndpoint' stats. + stats := s.Stats().ICMP.V4.PacketsSent - if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { - r.Stats().ICMP.V4.PacketsSent.Dropped.Increment() + if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil { + stats.Dropped.Increment() return err } - sentStat.Increment() + stats.EchoRequest.Increment() return nil } -func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error { +func send6(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, src, dst tcpip.Address, maxHeaderLength uint16) tcpip.Error { if len(data) < header.ICMPv6EchoMinimumSize { return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()), + ReserveHeaderBytes: header.ICMPv6MinimumSize + int(maxHeaderLength), }) icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize)) @@ -469,43 +405,31 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { return &tcpip.ErrInvalidEndpointState{} } - // Because this icmp endpoint is implemented in the transport layer, we can - // only increment the 'stack-wide' stats but we can't increment the - // 'per-NetworkEndpoint' stats. - sentStat := r.Stats().ICMP.V6.PacketsSent.EchoRequest pkt.Data().AppendView(data) dataRange := pkt.Data().AsRange() icmpv6.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpv6, - Src: r.LocalAddress(), - Dst: r.RemoteAddress(), + Src: src, + Dst: dst, PayloadCsum: dataRange.Checksum(), PayloadLen: dataRange.Size(), })) - if ttl == 0 { - ttl = r.DefaultTTL() - } + // Because this icmp endpoint is implemented in the transport layer, we can + // only increment the 'stack-wide' stats but we can't increment the + // 'per-NetworkEndpoint' stats. + stats := s.Stats().ICMP.V6.PacketsSent - if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { - r.Stats().ICMP.V6.PacketsSent.Dropped.Increment() + if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil { + stats.Dropped.Increment() + return err } - sentStat.Increment() + stats.EchoRequest.Increment() return nil } -// checkV4MappedLocked determines the effective network protocol and converts -// addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */) - if err != nil { - return tcpip.FullAddress{}, 0, err - } - return unwrapped, netProto, nil -} - // Disconnect implements tcpip.Endpoint.Disconnect. func (*endpoint) Disconnect() tcpip.Error { return &tcpip.ErrNotSupported{} @@ -516,59 +440,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicID := addr.NIC - localPort := uint16(0) - switch e.state { - case stateInitial: - case stateBound, stateConnected: - localPort = e.ID.LocalPort - if e.BindNICID == 0 { - break - } + err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error { + nextID.LocalPort = e.ident - if nicID != 0 && nicID != e.BindNICID { - return &tcpip.ErrInvalidEndpointState{} + nextID, err := e.registerWithStack(netProto, nextID) + if err != nil { + return err } - nicID = e.BindNICID - default: - return &tcpip.ErrInvalidEndpointState{} - } - - addr, netProto, err := e.checkV4MappedLocked(addr) - if err != nil { - return err - } - - // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) - if err != nil { - return err - } - - id := stack.TransportEndpointID{ - LocalAddress: r.LocalAddress(), - LocalPort: localPort, - RemoteAddress: r.RemoteAddress(), - } - - // Even if we're connected, this endpoint can still be used to send - // packets on a different network protocol, so we register both even if - // v6only is set to false and this is an ipv6 endpoint. - netProtos := []tcpip.NetworkProtocolNumber{netProto} - - id, err = e.registerWithStack(nicID, netProtos, id) + e.ident = nextID.LocalPort + return nil + }) if err != nil { - r.Release() return err } - e.ID = id - e.route = r - e.RegisterNICID = nicID - - e.state = stateConnected - e.rcvMu.Lock() e.rcvReady = true e.rcvMu.Unlock() @@ -586,10 +472,19 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - e.shutdownFlags |= flags - if e.state != stateConnected { + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: return &tcpip.ErrNotConnected{} + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + default: + panic(fmt.Sprintf("unhandled state = %s", state)) + } + + if flags&tcpip.ShutdownWrite != 0 { + if err := e.net.Shutdown(); err != nil { + return err + } } if flags&tcpip.ShutdownRead != 0 { @@ -616,19 +511,18 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { +func (e *endpoint) registerWithStack(netProto tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice) - return id, err + return id, e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice) } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice) + err := e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice) switch err.(type) { case nil: return true, nil @@ -645,42 +539,27 @@ func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkPro func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. - if e.state != stateInitial { + if e.net.State() != transport.DatagramEndpointStateInitial { return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) - if err != nil { - return err - } - - // Expand netProtos to include v4 and v6 if the caller is binding to a - // wildcard (empty) address, and this is an IPv6 endpoint with v6only - // set to false. - netProtos := []tcpip.NetworkProtocolNumber{netProto} - - if len(addr.Addr) != 0 { - // A local address was specified, verify that it's valid. - if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 { - return &tcpip.ErrBadLocalAddress{} + err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error { + id := stack.TransportEndpointID{ + LocalPort: addr.Port, + LocalAddress: addr.Addr, + } + id, err := e.registerWithStack(boundNetProto, id) + if err != nil { + return err } - } - id := stack.TransportEndpointID{ - LocalPort: addr.Port, - LocalAddress: addr.Addr, - } - id, err = e.registerWithStack(addr.NIC, netProtos, id) + e.ident = id.LocalPort + return nil + }) if err != nil { return err } - e.ID = id - e.RegisterNICID = addr.NIC - - // Mark endpoint as bound. - e.state = stateBound - e.rcvMu.Lock() e.rcvReady = true e.rcvMu.Unlock() @@ -688,21 +567,24 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { return nil } +func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, addr tcpip.Address) bool { + return addr == header.IPv4Broadcast || + header.IsV4MulticastAddress(addr) || + header.IsV6MulticastAddress(addr) || + e.stack.IsSubnetBroadcast(nicID, e.net.NetProto(), addr) +} + // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() - - err := e.bindLocked(addr) - if err != nil { - return err + if len(addr.Addr) != 0 && e.isBroadcastOrMulticast(addr.NIC, addr.Addr) { + return &tcpip.ErrBadLocalAddress{} } - e.BindNICID = addr.NIC - e.BindAddr = addr.Addr + e.mu.Lock() + defer e.mu.Unlock() - return nil + return e.bindLocked(addr) } // GetLocalAddress returns the address to which the endpoint is bound. @@ -710,11 +592,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, - }, nil + addr := e.net.GetLocalAddress() + addr.Port = e.ident + return addr, nil } // GetRemoteAddress returns the address to which the endpoint is connected. @@ -722,15 +602,11 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.state != stateConnected { - return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} + if addr, connected := e.net.GetRemoteAddress(); connected { + return addr, nil } - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, - }, nil + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } // Readiness returns the current readiness of the endpoint. For example, if @@ -755,7 +631,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Only accept echo replies. - switch e.NetProto { + switch e.net.NetProto() { case header.IPv4ProtocolNumber: h := header.ICMPv4(pkt.TransportHeader().View()) if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { @@ -829,9 +705,9 @@ func (e *endpoint) State() uint32 { // Info returns a copy of the endpoint info. func (e *endpoint) Info() tcpip.EndpointInfo { e.mu.RLock() - // Make a copy of the endpoint info. - ret := e.TransportEndpointInfo - e.mu.RUnlock() + defer e.mu.RUnlock() + ret := e.net.Info() + ret.ID.LocalPort = e.ident return &ret } diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index b8b839e4a..dfe453ff9 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -15,11 +15,13 @@ package icmp import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" ) // saveReceivedAt is invoked by stateify. @@ -61,29 +63,24 @@ func (e *endpoint) beforeSave() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.thaw() + + e.net.Resume(s) + e.stack = s e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - if e.state != stateBound && e.state != stateConnected { - return - } - - var err tcpip.Error - if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */) + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + var err tcpip.Error + info := e.net.Info() + info.ID.LocalPort = e.ident + info.ID, err = e.registerWithStack(info.NetProto, info.ID) if err != nil { - panic(err) + panic(fmt.Sprintf("e.registerWithStack(%d, %#v): %s", info.NetProto, info.ID, err)) } - - e.ID.LocalAddress = e.route.LocalAddress() - } else if len(e.ID.LocalAddress) != 0 { // stateBound - if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 { - panic(&tcpip.ErrBadLocalAddress{}) - } - } - - e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID) - if err != nil { - panic(err) + e.ident = info.ID.LocalPort + default: + panic(fmt.Sprintf("unhandled state = %s", state)) } } diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go index cc950cbde..729f50e9a 100644 --- a/pkg/tcpip/transport/icmp/icmp_test.go +++ b/pkg/tcpip/transport/icmp/icmp_test.go @@ -55,8 +55,12 @@ func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name s t.Fatalf("s.CreateNIC(%d, _) = %s", id, err) } - if err := s.AddAddress(id, ipv4.ProtocolNumber, addrV4); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s) = %s", id, ipv4.ProtocolNumber, addrV4, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addrV4.WithPrefix(), + } + if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err) } s.AddRoute(tcpip.Route{ diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD index b1edce39b..3818cb04e 100644 --- a/pkg/tcpip/transport/internal/network/BUILD +++ b/pkg/tcpip/transport/internal/network/BUILD @@ -9,6 +9,7 @@ go_library( "endpoint_state.go", ], visibility = [ + "//pkg/tcpip/transport/icmp:__pkg__", "//pkg/tcpip/transport/raw:__pkg__", "//pkg/tcpip/transport/udp:__pkg__", ], diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go index 09b629022..fb31e5104 100644 --- a/pkg/tcpip/transport/internal/network/endpoint.go +++ b/pkg/tcpip/transport/internal/network/endpoint.go @@ -38,31 +38,65 @@ type Endpoint struct { netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber - // state holds a transport.DatagramBasedEndpointState. - // - // state must be read from/written to atomically. - state uint32 - - // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu wasBound bool - info stack.TransportEndpointInfo // owner is the owner of transmitted packets. - owner tcpip.PacketOwner - writeShutdown bool - effectiveNetProto tcpip.NetworkProtocolNumber - connectedRoute *stack.Route `state:"manual"` + // + // +checklocks:mu + owner tcpip.PacketOwner + // +checklocks:mu + writeShutdown bool + // +checklocks:mu + effectiveNetProto tcpip.NetworkProtocolNumber + // +checklocks:mu + connectedRoute *stack.Route `state:"manual"` + // +checklocks:mu multicastMemberships map[multicastMembership]struct{} // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu ttl uint8 // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu multicastTTL uint8 // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu multicastAddr tcpip.Address // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu multicastNICID tcpip.NICID - ipv4TOS uint8 - ipv6TClass uint8 + // +checklocks:mu + ipv4TOS uint8 + // +checklocks:mu + ipv6TClass uint8 + + // Lock ordering: mu > infoMu. + infoMu sync.RWMutex `state:"nosave"` + // info has a dedicated mutex so that we can avoid lock ordering violations + // when reading the endpoint's info. If we used mu, we need to guarantee + // that any lock taken while mu is held is not held when calling Info() + // which is not true as of writing (we hold mu while registering transport + // endpoints (taking the transport demuxer lock but we also hold the demuxer + // lock when delivering packets/errors to endpoints). + // + // Writes must be performed through setInfo. + // + // +checklocks:infoMu + info stack.TransportEndpointInfo + + // state holds a transport.DatagramBasedEndpointState. + // + // state must be accessed with atomics so that we can avoid lock ordering + // violations when reading the state. If we used mu, we need to guarantee + // that any lock taken while mu is held is not held when calling State() + // which is not true as of writing (we hold mu while registering transport + // endpoints (taking the transport demuxer lock but we also hold the demuxer + // lock when delivering packets/errors to endpoints). + // + // Writes must be performed through setEndpointState. + // + // +checkatomics + state uint32 } // +stateify savable @@ -73,8 +107,11 @@ type multicastMembership struct { // Init initializes the endpoint. func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions) { - if e.multicastMemberships != nil { - panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", e.multicastMemberships)) + e.mu.Lock() + memberships := e.multicastMemberships + e.mu.Unlock() + if memberships != nil { + panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", memberships)) } switch netProto { @@ -89,8 +126,6 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr netProto: netProto, transProto: transProto, - state: uint32(transport.DatagramEndpointStateInitial), - info: stack.TransportEndpointInfo{ NetProto: netProto, TransProto: transProto, @@ -100,6 +135,10 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr multicastTTL: 1, multicastMemberships: make(map[multicastMembership]struct{}), } + + e.mu.Lock() + defer e.mu.Unlock() + e.setEndpointState(transport.DatagramEndpointStateInitial) } // NetProto returns the network protocol the endpoint was initialized with. @@ -107,7 +146,12 @@ func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber { return e.netProto } -// setState sets the state of the endpoint. +// setEndpointState sets the state of the endpoint. +// +// e.mu must be held to synchronize changes to state with the rest of the +// endpoint. +// +// +checklocks:e.mu func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) { atomic.StoreUint32(&e.state, uint32(state)) } @@ -242,23 +286,24 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext if nicID == 0 { nicID = tcpip.NICID(e.ops.GetBindToDevice()) } - if e.info.BindNICID != 0 { - if nicID != 0 && nicID != e.info.BindNICID { + info := e.Info() + if info.BindNICID != 0 { + if nicID != 0 && nicID != info.BindNICID { return WriteContext{}, &tcpip.ErrNoRoute{} } - nicID = e.info.BindNICID + nicID = info.BindNICID } if nicID == 0 { - nicID = e.info.RegisterNICID + nicID = info.RegisterNICID } - dst, netProto, err := e.checkV4MappedLocked(*opts.To) + dst, netProto, err := e.checkV4Mapped(*opts.To) if err != nil { return WriteContext{}, err } - route, _, err = e.connectRoute(nicID, dst, netProto) + route, _, err = e.connectRouteRLocked(nicID, dst, netProto) if err != nil { return WriteContext{}, err } @@ -297,26 +342,30 @@ func (e *Endpoint) Disconnect() { return } + info := e.Info() // Exclude ephemerally bound endpoints. if e.wasBound { - e.info.ID = stack.TransportEndpointID{ - LocalAddress: e.info.BindAddr, + info.ID = stack.TransportEndpointID{ + LocalAddress: info.BindAddr, } e.setEndpointState(transport.DatagramEndpointStateBound) } else { - e.info.ID = stack.TransportEndpointID{} + info.ID = stack.TransportEndpointID{} e.setEndpointState(transport.DatagramEndpointStateInitial) } + e.setInfo(info) e.connectedRoute.Release() e.connectedRoute = nil } -// connectRoute establishes a route to the specified interface or the +// connectRouteRLocked establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *Endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { - localAddr := e.info.ID.LocalAddress +// +// +checklocksread:e.mu +func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { + localAddr := e.Info().ID.LocalAddress if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { // A packet can only originate from a unicast address (i.e., an interface). localAddr = "" @@ -359,42 +408,43 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. e.mu.Lock() defer e.mu.Unlock() + info := e.Info() nicID := addr.NIC switch e.State() { case transport.DatagramEndpointStateInitial: case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: - if e.info.BindNICID == 0 { + if info.BindNICID == 0 { break } - if nicID != 0 && nicID != e.info.BindNICID { + if nicID != 0 && nicID != info.BindNICID { return &tcpip.ErrInvalidEndpointState{} } - nicID = e.info.BindNICID + nicID = info.BindNICID default: return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) + addr, netProto, err := e.checkV4Mapped(addr) if err != nil { return err } - r, nicID, err := e.connectRoute(nicID, addr, netProto) + r, nicID, err := e.connectRouteRLocked(nicID, addr, netProto) if err != nil { return err } id := stack.TransportEndpointID{ - LocalAddress: e.info.ID.LocalAddress, + LocalAddress: info.ID.LocalAddress, RemoteAddress: r.RemoteAddress(), } if e.State() == transport.DatagramEndpointStateInitial { id.LocalAddress = r.LocalAddress() } - if err := f(r.NetProto(), e.info.ID, id); err != nil { + if err := f(r.NetProto(), info.ID, id); err != nil { return err } @@ -403,8 +453,9 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. e.connectedRoute.Release() } e.connectedRoute = r - e.info.ID = id - e.info.RegisterNICID = nicID + info.ID = id + info.RegisterNICID = nicID + e.setInfo(info) e.effectiveNetProto = netProto e.setEndpointState(transport.DatagramEndpointStateConnected) return nil @@ -426,10 +477,11 @@ func (e *Endpoint) Shutdown() tcpip.Error { } } -// checkV4MappedLocked determines the effective network protocol and converts +// checkV4MappedRLocked determines the effective network protocol and converts // addr to its canonical form. -func (e *Endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { - unwrapped, netProto, err := e.info.AddrNetProtoLocked(addr, e.ops.GetV6Only()) +func (e *Endpoint) checkV4Mapped(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { + info := e.Info() + unwrapped, netProto, err := info.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err } @@ -464,7 +516,7 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) + addr, netProto, err := e.checkV4Mapped(addr) if err != nil { return err } @@ -483,12 +535,14 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto e.wasBound = true - e.info.ID = stack.TransportEndpointID{ + info := e.Info() + info.ID = stack.TransportEndpointID{ LocalAddress: addr.Addr, } - e.info.BindNICID = addr.NIC - e.info.RegisterNICID = nicID - e.info.BindAddr = addr.Addr + info.BindNICID = addr.NIC + info.RegisterNICID = nicID + info.BindAddr = addr.Addr + e.setInfo(info) e.effectiveNetProto = netProto e.setEndpointState(transport.DatagramEndpointStateBound) return nil @@ -506,13 +560,14 @@ func (e *Endpoint) GetLocalAddress() tcpip.FullAddress { e.mu.RLock() defer e.mu.RUnlock() - addr := e.info.BindAddr + info := e.Info() + addr := info.BindAddr if e.State() == transport.DatagramEndpointStateConnected { addr = e.connectedRoute.LocalAddress() } return tcpip.FullAddress{ - NIC: e.info.RegisterNICID, + NIC: info.RegisterNICID, Addr: addr, } } @@ -528,7 +583,7 @@ func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) { return tcpip.FullAddress{ Addr: e.connectedRoute.RemoteAddress(), - NIC: e.info.RegisterNICID, + NIC: e.Info().RegisterNICID, }, true } @@ -610,7 +665,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { defer e.mu.Unlock() fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - fa, netProto, err := e.checkV4MappedLocked(fa) + fa, netProto, err := e.checkV4Mapped(fa) if err != nil { return err } @@ -634,7 +689,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } } - if e.info.BindNICID != 0 && e.info.BindNICID != nic { + if info := e.Info(); info.BindNICID != 0 && info.BindNICID != nic { return &tcpip.ErrInvalidEndpointState{} } @@ -737,7 +792,19 @@ func (e *Endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { // Info returns a copy of the endpoint info. func (e *Endpoint) Info() stack.TransportEndpointInfo { - e.mu.RLock() - defer e.mu.RUnlock() + e.infoMu.RLock() + defer e.infoMu.RUnlock() return e.info } + +// setInfo sets the endpoint's info. +// +// e.mu must be held to synchronize changes to info with the rest of the +// endpoint. +// +// +checklocks:e.mu +func (e *Endpoint) setInfo(info stack.TransportEndpointInfo) { + e.infoMu.Lock() + defer e.infoMu.Unlock() + e.info = info +} diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go index 858007156..68bd1fbf6 100644 --- a/pkg/tcpip/transport/internal/network/endpoint_state.go +++ b/pkg/tcpip/transport/internal/network/endpoint_state.go @@ -35,20 +35,22 @@ func (e *Endpoint) Resume(s *stack.Stack) { } } + info := e.Info() + switch state := e.State(); state { case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: case transport.DatagramEndpointStateBound: - if len(e.info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) { - if e.stack.CheckLocalAddress(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) == 0 { - panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress)) + if len(info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) { + if e.stack.CheckLocalAddress(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) == 0 { + panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress)) } } case transport.DatagramEndpointStateConnected: var err tcpip.Error multicastLoop := e.ops.GetMulticastLoop() - e.connectedRoute, err = e.stack.FindRoute(e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop) + e.connectedRoute, err = e.stack.FindRoute(info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop) if err != nil { - panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err)) + panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err)) } default: panic(fmt.Sprintf("unhandled state = %s", state)) diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go index d99c961c3..f263a9ea2 100644 --- a/pkg/tcpip/transport/internal/network/endpoint_test.go +++ b/pkg/tcpip/transport/internal/network/endpoint_test.go @@ -124,11 +124,20 @@ func TestEndpointStateTransitions(t *testing.T) { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4NICAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4NICAddr, err) + ipv4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ipv4NICAddr.WithPrefix(), } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6NICAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6NICAddr, err) + if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}: %s", nicID, ipv4ProtocolAddr, err) + } + ipv6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: ipv6NICAddr.WithPrefix(), + } + + if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -257,11 +266,19 @@ func TestBindNICID(t *testing.T) { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4NICAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4NICAddr, err) + ipv4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ipv4NICAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtocolAddr, err) + } + ipv6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: ipv6NICAddr.WithPrefix(), } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6NICAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6NICAddr, err) + if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err) } var ops tcpip.SocketOptions diff --git a/pkg/tcpip/transport/internal/noop/BUILD b/pkg/tcpip/transport/internal/noop/BUILD new file mode 100644 index 000000000..171c41eb1 --- /dev/null +++ b/pkg/tcpip/transport/internal/noop/BUILD @@ -0,0 +1,14 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "noop", + srcs = ["endpoint.go"], + visibility = ["//pkg/tcpip/transport/raw:__pkg__"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/stack", + "//pkg/waiter", + ], +) diff --git a/pkg/tcpip/transport/internal/noop/endpoint.go b/pkg/tcpip/transport/internal/noop/endpoint.go new file mode 100644 index 000000000..443b4e416 --- /dev/null +++ b/pkg/tcpip/transport/internal/noop/endpoint.go @@ -0,0 +1,172 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package noop contains an endpoint that implements all tcpip.Endpoint +// functions as noops. +package noop + +import ( + "fmt" + "io" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/waiter" +) + +// endpoint can be created, but all interactions have no effect or +// return errors. +// +// +stateify savable +type endpoint struct { + tcpip.DefaultSocketOptionsHandler + ops tcpip.SocketOptions +} + +// New returns an initialized noop endpoint. +func New(stk *stack.Stack) tcpip.Endpoint { + // ep.ops must be in a valid, initialized state for callers of + // ep.SocketOptions. + var ep endpoint + ep.ops.InitHandler(&ep, stk, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) + return &ep +} + +// Abort implements stack.TransportEndpoint.Abort. +func (*endpoint) Abort() { + // No-op. +} + +// Close implements tcpip.Endpoint.Close. +func (*endpoint) Close() { + // No-op. +} + +// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. +func (*endpoint) ModerateRecvBuf(int) { + // No-op. +} + +func (*endpoint) SetOwner(tcpip.PacketOwner) { + // No-op. +} + +// Read implements tcpip.Endpoint.Read. +func (*endpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { + return tcpip.ReadResult{}, &tcpip.ErrNotPermitted{} +} + +// Write implements tcpip.Endpoint.Write. +func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) { + return 0, &tcpip.ErrNotPermitted{} +} + +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*endpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} +} + +// Connect implements tcpip.Endpoint.Connect. +func (*endpoint) Connect(tcpip.FullAddress) tcpip.Error { + return &tcpip.ErrNotPermitted{} +} + +// Shutdown implements tcpip.Endpoint.Shutdown. +func (*endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { + return &tcpip.ErrNotPermitted{} +} + +// Listen implements tcpip.Endpoint.Listen. +func (*endpoint) Listen(int) tcpip.Error { + return &tcpip.ErrNotSupported{} +} + +// Accept implements tcpip.Endpoint.Accept. +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { + return nil, nil, &tcpip.ErrNotSupported{} +} + +// Bind implements tcpip.Endpoint.Bind. +func (*endpoint) Bind(tcpip.FullAddress) tcpip.Error { + return &tcpip.ErrNotPermitted{} +} + +// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { + return tcpip.FullAddress{}, &tcpip.ErrNotSupported{} +} + +// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} +} + +// Readiness implements tcpip.Endpoint.Readiness. +func (*endpoint) Readiness(waiter.EventMask) waiter.EventMask { + return 0 +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. +func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} +} + +func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} +} + +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (*endpoint) GetSockOptInt(tcpip.SockOptInt) (int, tcpip.Error) { + return 0, &tcpip.ErrUnknownProtocolOption{} +} + +// HandlePacket implements stack.RawTransportEndpoint.HandlePacket. +func (*endpoint) HandlePacket(pkt *stack.PacketBuffer) { + panic(fmt.Sprintf("unreachable: noop.endpoint should never be registered, but got packet: %+v", pkt)) +} + +// State implements socket.Socket.State. +func (*endpoint) State() uint32 { + return 0 +} + +// Wait implements stack.TransportEndpoint.Wait. +func (*endpoint) Wait() { + // No-op. +} + +// LastError implements tcpip.Endpoint.LastError. +func (*endpoint) LastError() tcpip.Error { + return nil +} + +// SocketOptions implements tcpip.Endpoint.SocketOptions. +func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { + return &ep.ops +} + +// Info implements tcpip.Endpoint.Info. +func (*endpoint) Info() tcpip.EndpointInfo { + return &stack.TransportEndpointInfo{} +} + +// Stats returns a pointer to the endpoint stats. +func (*endpoint) Stats() tcpip.EndpointStats { + return &tcpip.TransportEndpointStats{} +} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 0554d2f4a..80eef39e9 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -59,52 +59,47 @@ type packet struct { // // +stateify savable type endpoint struct { - stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber waiterQueue *waiter.Queue cooked bool - - // The following fields are used to manage the receive queue and are - // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvList packetList + ops tcpip.SocketOptions + stats tcpip.TransportEndpointStats + + // The following fields are used to manage the receive queue. + rcvMu sync.Mutex `state:"nosave"` + // +checklocks:rcvMu + rcvList packetList + // +checklocks:rcvMu rcvBufSize int - rcvClosed bool - - // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - closed bool - stats tcpip.TransportEndpointStats `state:"nosave"` - bound bool + // +checklocks:rcvMu + rcvClosed bool + // +checklocks:rcvMu + rcvDisabled bool + + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu + closed bool + // +checklocks:mu + boundNetProto tcpip.NetworkProtocolNumber + // +checklocks:mu boundNIC tcpip.NICID - // lastErrorMu protects lastError. lastErrorMu sync.Mutex `state:"nosave"` - lastError tcpip.Error - - // ops is used to get socket level options. - ops tcpip.SocketOptions - - // frozen indicates if the packets should be delivered to the endpoint - // during restore. - frozen bool + // +checklocks:lastErrorMu + lastError tcpip.Error } // NewEndpoint returns a new packet endpoint. func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ - stack: s, - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - }, - cooked: cooked, - netProto: netProto, - waiterQueue: waiterQueue, + stack: s, + cooked: cooked, + boundNetProto: netProto, + waiterQueue: waiterQueue, } ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) @@ -140,7 +135,7 @@ func (ep *endpoint) Close() { return } - ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep) ep.rcvMu.Lock() defer ep.rcvMu.Unlock() @@ -153,7 +148,6 @@ func (ep *endpoint) Close() { } ep.closed = true - ep.bound = false ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) } @@ -188,7 +182,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul Total: packet.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: packet.receivedAt.UnixNano(), + Timestamp: packet.receivedAt, }, } if opts.NeedRemoteAddr { @@ -214,13 +208,13 @@ func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tc ep.mu.Lock() closed := ep.closed nicID := ep.boundNIC + proto := ep.boundNetProto ep.mu.Unlock() if closed { return 0, &tcpip.ErrClosedForSend{} } var remote tcpip.LinkAddress - proto := ep.netProto if to := opts.To; to != nil { remote = tcpip.LinkAddress(to.Addr) @@ -296,29 +290,42 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - if ep.bound && ep.boundNIC == addr.NIC { - // If the NIC being bound is the same then just return success. + netProto := tcpip.NetworkProtocolNumber(addr.Port) + if netProto == 0 { + // Do not allow unbinding the network protocol. + netProto = ep.boundNetProto + } + + if ep.boundNIC == addr.NIC && ep.boundNetProto == netProto { + // Already bound to the requested NIC and network protocol. return nil } - // Unregister endpoint with all the nics. - ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) - ep.bound = false + // TODO(https://gvisor.dev/issue/6618): Unregister after registering the new + // binding. + ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep) + ep.boundNIC = 0 + ep.boundNetProto = 0 // Bind endpoint to receive packets from specific interface. - if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil { + if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil { return err } - ep.bound = true ep.boundNIC = addr.NIC - + ep.boundNetProto = netProto return nil } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { - return tcpip.FullAddress{}, &tcpip.ErrNotSupported{} +func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { + ep.mu.RLock() + defer ep.mu.RUnlock() + + return tcpip.FullAddress{ + NIC: ep.boundNIC, + Port: uint16(ep.boundNetProto), + }, nil } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. @@ -402,7 +409,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { } // HandlePacket implements stack.PacketEndpoint.HandlePacket. -func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +func (ep *endpoint) HandlePacket(nicID tcpip.NICID, _ tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { ep.rcvMu.Lock() // Drop the packet if our buffer is currently full. @@ -414,7 +421,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, } rcvBufSize := ep.ops.GetReceiveBufferSize() - if ep.frozen || ep.rcvBufSize >= int(rcvBufSize) { + if ep.rcvDisabled || ep.rcvBufSize >= int(rcvBufSize) { ep.rcvMu.Unlock() ep.stack.Stats().DroppedPackets.Increment() ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -473,10 +480,8 @@ func (*endpoint) State() uint32 { // Info returns a copy of the endpoint info. func (ep *endpoint) Info() tcpip.EndpointInfo { ep.mu.RLock() - // Make a copy of the endpoint info. - ret := ep.TransportEndpointInfo - ep.mu.RUnlock() - return &ret + defer ep.mu.RUnlock() + return &stack.TransportEndpointInfo{NetProto: ep.boundNetProto} } // Stats returns a pointer to the endpoint stats. @@ -491,18 +496,3 @@ func (*endpoint) SetOwner(tcpip.PacketOwner) {} func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { return &ep.ops } - -// freeze prevents any more packets from being delivered to the endpoint. -func (ep *endpoint) freeze() { - ep.mu.Lock() - ep.frozen = true - ep.mu.Unlock() -} - -// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows -// new packets to be delivered again. -func (ep *endpoint) thaw() { - ep.mu.Lock() - ep.frozen = false - ep.mu.Unlock() -} diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index 5c688d286..88cd80ad3 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -15,6 +15,7 @@ package packet import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/tcpip" @@ -44,17 +45,24 @@ func (p *packet) loadData(data buffer.VectorisedView) { // beforeSave is invoked by stateify. func (ep *endpoint) beforeSave() { - ep.freeze() + ep.rcvMu.Lock() + defer ep.rcvMu.Unlock() + ep.rcvDisabled = true } // afterLoad is invoked by stateify. func (ep *endpoint) afterLoad() { - ep.thaw() + ep.mu.Lock() + defer ep.mu.Unlock() + ep.stack = stack.StackFromEnv ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. - if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { - panic(err) + if err := ep.stack.RegisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep); err != nil { + panic(fmt.Sprintf("RegisterPacketEndpoint(%d, %d, _): %s", ep.boundNIC, ep.boundNetProto, err)) } + + ep.rcvMu.Lock() + ep.rcvDisabled = false + ep.rcvMu.Unlock() } diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index b7e97e218..10b0c35fb 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -35,6 +35,7 @@ go_library( "//pkg/tcpip/stack", "//pkg/tcpip/transport", "//pkg/tcpip/transport/internal/network", + "//pkg/tcpip/transport/internal/noop", "//pkg/tcpip/transport/packet", "//pkg/waiter", ], diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 3040a445b..ce76774af 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -49,6 +49,7 @@ type rawPacket struct { receivedAt time.Time `state:".(int64)"` // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress + packetInfo tcpip.IPPacketInfo } // endpoint is the raw socket implementation of tcpip.Endpoint. It is legal to @@ -70,7 +71,7 @@ type endpoint struct { associated bool net network.Endpoint - stats tcpip.TransportEndpointStats `state:"nosave"` + stats tcpip.TransportEndpointStats ops tcpip.SocketOptions // The following fields are used to manage the receive queue and are @@ -202,12 +203,29 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult Total: pkt.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: pkt.receivedAt.UnixNano(), + Timestamp: pkt.receivedAt, }, } if opts.NeedRemoteAddr { res.RemoteAddr = pkt.senderAddr } + switch netProto := e.net.NetProto(); netProto { + case header.IPv4ProtocolNumber: + if e.ops.GetReceivePacketInfo() { + res.ControlMessages.HasIPPacketInfo = true + res.ControlMessages.PacketInfo = pkt.packetInfo + } + case header.IPv6ProtocolNumber: + if e.ops.GetIPv6ReceivePacketInfo() { + res.ControlMessages.HasIPv6PacketInfo = true + res.ControlMessages.IPv6PacketInfo = tcpip.IPv6PacketInfo{ + NIC: pkt.packetInfo.NIC, + Addr: pkt.packetInfo.DestinationAddr, + } + } + default: + panic(fmt.Sprintf("unrecognized network protocol = %d", netProto)) + } n, err := pkt.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { @@ -435,7 +453,9 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return false } - srcAddr := pkt.Network().SourceAddress() + net := pkt.Network() + dstAddr := net.DestinationAddress() + srcAddr := net.SourceAddress() info := e.net.Info() switch state := e.net.State(); state { @@ -457,7 +477,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } // If bound to an address, only accept data for that address. - if info.BindAddr != "" && info.BindAddr != pkt.Network().DestinationAddress() { + if info.BindAddr != "" && info.BindAddr != dstAddr { return false } default: @@ -472,6 +492,14 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { NIC: pkt.NICID, Addr: srcAddr, }, + packetInfo: tcpip.IPPacketInfo{ + // TODO(gvisor.dev/issue/3556): dstAddr may be a multicast or broadcast + // address. LocalAddr should hold a unicast address that can be + // used to respond to the incoming packet. + LocalAddr: dstAddr, + DestinationAddr: dstAddr, + NIC: pkt.NICID, + }, } // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not. @@ -483,10 +511,10 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // overlapping slices. var combinedVV buffer.VectorisedView if info.NetProto == header.IPv4ProtocolNumber { - network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() - headers := make(buffer.View, 0, len(network)+len(transport)) - headers = append(headers, network...) - headers = append(headers, transport...) + networkHeader, transportHeader := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + headers := make(buffer.View, 0, len(networkHeader)+len(transportHeader)) + headers = append(headers, networkHeader...) + headers = append(headers, transportHeader...) combinedVV = headers.ToVectorisedView() } else { combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView() diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go index e393b993d..624e2dbe7 100644 --- a/pkg/tcpip/transport/raw/protocol.go +++ b/pkg/tcpip/transport/raw/protocol.go @@ -17,6 +17,7 @@ package raw import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/internal/noop" "gvisor.dev/gvisor/pkg/tcpip/transport/packet" "gvisor.dev/gvisor/pkg/waiter" ) @@ -33,3 +34,18 @@ func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpi func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return packet.NewEndpoint(stack, cooked, netProto, waiterQueue) } + +// CreateOnlyFactory implements stack.RawFactory. It allows creation of raw +// endpoints that do not support reading, writing, binding, etc. +type CreateOnlyFactory struct{} + +// NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint. +func (CreateOnlyFactory) NewUnassociatedEndpoint(stk *stack.Stack, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { + return noop.New(stk), nil +} + +// NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint. +func (CreateOnlyFactory) NewPacketEndpoint(*stack.Stack, bool, tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { + // This isn't needed by anything, so it isn't implemented. + return nil, &tcpip.ErrNotPermitted{} +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 5148fe157..20958d882 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -80,9 +80,10 @@ go_library( go_test( name = "tcp_x_test", - size = "medium", + size = "large", srcs = [ "dual_stack_test.go", + "rcv_test.go", "sack_scoreboard_test.go", "tcp_noracedetector_test.go", "tcp_rack_test.go", @@ -114,16 +115,6 @@ go_test( ) go_test( - name = "rcv_test", - size = "small", - srcs = ["rcv_test.go"], - deps = [ - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - ], -) - -go_test( name = "tcp_test", size = "small", srcs = [ diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 03c9fafa1..caf14b0dc 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -15,12 +15,12 @@ package tcp import ( + "container/list" "crypto/sha1" "encoding/binary" "fmt" "hash" "io" - "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sleep" @@ -100,18 +100,6 @@ type listenContext struct { // netProto indicates the network protocol(IPv4/v6) for the listening // endpoint. netProto tcpip.NetworkProtocolNumber - - // pendingMu protects pendingEndpoints. This should only be accessed - // by the listening endpoint's worker goroutine. - // - // Lock Ordering: listenEP.workerMu -> pendingMu - pendingMu sync.Mutex - // pending is used to wait for all pendingEndpoints to finish when - // a socket is closed. - pending sync.WaitGroup - // pendingEndpoints is a map of all endpoints for which a handshake is - // in progress. - pendingEndpoints map[stack.TransportEndpointID]*endpoint } // timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. @@ -122,14 +110,13 @@ func timeStamp(clock tcpip.Clock) uint32 { // newListenContext creates a new listen context. func newListenContext(stk *stack.Stack, protocol *protocol, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ - stack: stk, - protocol: protocol, - rcvWnd: rcvWnd, - hasher: sha1.New(), - v6Only: v6Only, - netProto: netProto, - listenEP: listenEP, - pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), + stack: stk, + protocol: protocol, + rcvWnd: rcvWnd, + hasher: sha1.New(), + v6Only: v6Only, + netProto: netProto, + listenEP: listenEP, } for i := range l.nonce { @@ -193,14 +180,6 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true } -func (l *listenContext) useSynCookies() bool { - var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies - if err := l.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil { - panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err)) - } - return bool(alwaysUseSynCookies) || (l.listenEP != nil && l.listenEP.synRcvdBacklogFull()) -} - // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { @@ -273,18 +252,15 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu return nil, &tcpip.ErrConnectionAborted{} } - l.addPendingEndpoint(ep) // Propagate any inheritable options from the listening endpoint // to the newly created endpoint. - l.listenEP.propagateInheritableOptionsLocked(ep) + l.listenEP.propagateInheritableOptionsLocked(ep) // +checklocksforce if !ep.reserveTupleLocked() { ep.mu.Unlock() ep.Close() - l.removePendingEndpoint(ep) - return nil, &tcpip.ErrConnectionAborted{} } @@ -303,10 +279,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu ep.mu.Unlock() ep.Close() - if l.listenEP != nil { - l.removePendingEndpoint(ep) - } - ep.drainClosingSegmentQueue() return nil, err @@ -344,39 +316,12 @@ func (l *listenContext) performHandshake(s *segment, opts header.TCPSynOptions, return ep, nil } -func (l *listenContext) addPendingEndpoint(n *endpoint) { - l.pendingMu.Lock() - l.pendingEndpoints[n.TransportEndpointInfo.ID] = n - l.pending.Add(1) - l.pendingMu.Unlock() -} - -func (l *listenContext) removePendingEndpoint(n *endpoint) { - l.pendingMu.Lock() - delete(l.pendingEndpoints, n.TransportEndpointInfo.ID) - l.pending.Done() - l.pendingMu.Unlock() -} - -func (l *listenContext) closeAllPendingEndpoints() { - l.pendingMu.Lock() - for _, n := range l.pendingEndpoints { - n.notifyProtocolGoroutine(notifyClose) - } - l.pendingMu.Unlock() - l.pending.Wait() -} - -// Precondition: h.ep.mu must be held. // +checklocks:h.ep.mu func (l *listenContext) cleanupFailedHandshake(h *handshake) { e := h.ep e.mu.Unlock() e.Close() e.notifyAborted() - if l.listenEP != nil { - l.removePendingEndpoint(e) - } e.drainClosingSegmentQueue() e.h = nil } @@ -384,12 +329,9 @@ func (l *listenContext) cleanupFailedHandshake(h *handshake) { // cleanupCompletedHandshake transfers any state from the completed handshake to // the new endpoint. // -// Precondition: h.ep.mu must be held. +// +checklocks:h.ep.mu func (l *listenContext) cleanupCompletedHandshake(h *handshake) { e := h.ep - if l.listenEP != nil { - l.removePendingEndpoint(e) - } e.isConnectNotified = true // Update the receive window scaling. We can't do it before the @@ -401,47 +343,11 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) { e.h = nil } -// deliverAccepted delivers the newly-accepted endpoint to the listener. If the -// listener has transitioned out of the listen state (accepted is the zero -// value), the new endpoint is reset instead. -func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) { - e.mu.Lock() - e.pendingAccepted.Add(1) - e.mu.Unlock() - defer e.pendingAccepted.Done() - - // Drop the lock before notifying to avoid deadlock in user-specified - // callbacks. - delivered := func() bool { - e.acceptMu.Lock() - defer e.acceptMu.Unlock() - for { - if e.accepted == (accepted{}) { - return false - } - if e.accepted.endpoints.Len() == e.accepted.cap { - e.acceptCond.Wait() - continue - } - - e.accepted.endpoints.PushBack(n) - if !withSynCookie { - atomic.AddInt32(&e.synRcvdCount, -1) - } - return true - } - }() - if delivered { - e.waiterQueue.Notify(waiter.ReadableEvents) - } else { - n.notifyProtocolGoroutine(notifyReset) - } -} - // propagateInheritableOptionsLocked propagates any options set on the listening // endpoint to the newly created endpoint. // -// Precondition: e.mu and n.mu must be held. +// +checklocks:e.mu +// +checklocks:n.mu func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { n.userTimeout = e.userTimeout n.portFlags = e.portFlags @@ -452,9 +358,9 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { // reserveTupleLocked reserves an accepted endpoint's tuple. // -// Preconditions: -// * propagateInheritableOptionsLocked has been called. -// * e.mu is held. +// Precondition: e.propagateInheritableOptionsLocked has been called. +// +// +checklocks:e.mu func (e *endpoint) reserveTupleLocked() bool { dest := tcpip.FullAddress{ Addr: e.TransportEndpointInfo.ID.RemoteAddress, @@ -489,70 +395,36 @@ func (e *endpoint) notifyAborted() { e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) } -// handleSynSegment is called in its own goroutine once the listening endpoint -// receives a SYN segment. It is responsible for completing the handshake and -// queueing the new endpoint for acceptance. -// -// A limited number of these goroutines are allowed before TCP starts using SYN -// cookies to accept connections. -// -// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. -func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts header.TCPSynOptions) tcpip.Error { - defer s.decRef() +func (e *endpoint) acceptQueueIsFull() bool { + e.acceptMu.Lock() + full := e.acceptQueue.isFull() + e.acceptMu.Unlock() + return full +} - h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) - if err != nil { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - e.stats.FailedConnectionAttempts.Increment() - atomic.AddInt32(&e.synRcvdCount, -1) - return err - } +// +stateify savable +type acceptQueue struct { + // NB: this could be an endpointList, but ilist only permits endpoints to + // belong to one list at a time, and endpoints are already stored in the + // dispatcher's list. + endpoints list.List `state:".([]*endpoint)"` - go func() { - // Note that startHandshake returns a locked endpoint. The - // force call here just makes it so. - if err := h.complete(); err != nil { // +checklocksforce - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - e.stats.FailedConnectionAttempts.Increment() - ctx.cleanupFailedHandshake(h) - atomic.AddInt32(&e.synRcvdCount, -1) - return - } - ctx.cleanupCompletedHandshake(h) - h.ep.startAcceptedLoop() - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - e.deliverAccepted(h.ep, false /*withSynCookie*/) - }() + // pendingEndpoints is a set of all endpoints for which a handshake is + // in progress. + pendingEndpoints map[*endpoint]struct{} - return nil + // capacity is the maximum number of endpoints that can be in endpoints. + capacity int } -func (e *endpoint) synRcvdBacklogFull() bool { - e.acceptMu.Lock() - acceptedCap := e.accepted.cap - e.acceptMu.Unlock() - // The capacity of the accepted queue would always be one greater than the - // listen backlog. But, the SYNRCVD connections count is always checked - // against the listen backlog value for Linux parity reason. - // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280 - // - // We maintain an equality check here as the synRcvdCount is incremented - // and compared only from a single listener context and the capacity of - // the accepted queue can only increase by a new listen call. - return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1 -} - -func (e *endpoint) acceptQueueIsFull() bool { - e.acceptMu.Lock() - full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap - e.acceptMu.Unlock() - return full +func (a *acceptQueue) isFull() bool { + return a.endpoints.Len() == a.capacity } // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. // -// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. +// +checklocks:e.mu func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error { e.rcvQueueInfo.rcvQueueMu.Lock() rcvClosed := e.rcvQueueInfo.RcvClosed @@ -580,11 +452,95 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err } opts := parseSynSegmentOptions(s) - if !ctx.useSynCookies() { - s.incRef() - atomic.AddInt32(&e.synRcvdCount, 1) - return e.handleSynSegment(ctx, s, opts) + + useSynCookies, err := func() (bool, tcpip.Error) { + var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies + if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil { + panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err)) + } + if alwaysUseSynCookies { + return true, nil + } + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + + // The capacity of the accepted queue would always be one greater than the + // listen backlog. But, the SYNRCVD connections count is always checked + // against the listen backlog value for Linux parity reason. + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280 + if len(e.acceptQueue.pendingEndpoints) == e.acceptQueue.capacity-1 { + return true, nil + } + + h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) + if err != nil { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + return false, err + } + + e.acceptQueue.pendingEndpoints[h.ep] = struct{}{} + e.pendingAccepted.Add(1) + + go func() { + defer func() { + e.pendingAccepted.Done() + + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + delete(e.acceptQueue.pendingEndpoints, h.ep) + }() + + // Note that startHandshake returns a locked endpoint. The force call + // here just makes it so. + if err := h.complete(); err != nil { // +checklocksforce + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + ctx.cleanupFailedHandshake(h) + return + } + ctx.cleanupCompletedHandshake(h) + h.ep.startAcceptedLoop() + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + + // Deliver the endpoint to the accept queue. + // + // Drop the lock before notifying to avoid deadlock in user-specified + // callbacks. + delivered := func() bool { + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + for { + // The listener is transitioning out of the Listen state; bail. + if e.acceptQueue.capacity == 0 { + return false + } + if e.acceptQueue.isFull() { + e.acceptCond.Wait() + continue + } + + e.acceptQueue.endpoints.PushBack(h.ep) + return true + } + }() + + if delivered { + e.waiterQueue.Notify(waiter.ReadableEvents) + } else { + h.ep.notifyProtocolGoroutine(notifyReset) + } + }() + + return false, nil + }() + if err != nil { + return err + } + if !useSynCookies { + return nil } + route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) if err != nil { return err @@ -627,18 +583,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil case s.flags.Contains(header.TCPFlagAck): - if e.acceptQueueIsFull() { - // Silently drop the ack as the application can't accept - // the connection at this point. The ack will be - // retransmitted by the sender anyway and we can - // complete the connection at the time of retransmit if - // the backlog has space. - e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() - e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment() - e.stack.Stats().DroppedPackets.Increment() - return nil - } - iss := s.ackNumber - 1 irs := s.sequenceNumber - 1 @@ -674,6 +618,24 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err // ACK was received from the sender. return replyWithReset(e.stack, s, e.sendTOS, e.ttl) } + + // Keep hold of acceptMu until the new endpoint is in the accept queue (or + // if there is an error), to guarantee that we will keep our spot in the + // queue even if another handshake from the syn queue completes. + e.acceptMu.Lock() + if e.acceptQueue.isFull() { + // Silently drop the ack as the application can't accept + // the connection at this point. The ack will be + // retransmitted by the sender anyway and we can + // complete the connection at the time of retransmit if + // the backlog has space. + e.acceptMu.Unlock() + e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() + e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment() + e.stack.Stats().DroppedPackets.Increment() + return nil + } + e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment() // Create newly accepted endpoint and deliver it. rcvdSynOptions := header.TCPSynOptions{ @@ -695,6 +657,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{}) if err != nil { + e.acceptMu.Unlock() return err } @@ -706,6 +669,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err if !n.reserveTupleLocked() { n.mu.Unlock() + e.acceptMu.Unlock() n.Close() e.stack.Stats().TCP.FailedConnectionAttempts.Increment() @@ -723,6 +687,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err n.boundBindToDevice, ); err != nil { n.mu.Unlock() + e.acceptMu.Unlock() n.Close() e.stack.Stats().TCP.FailedConnectionAttempts.Increment() @@ -755,20 +720,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err n.newSegmentWaker.Assert() } - // Do the delivery in a separate goroutine so - // that we don't block the listen loop in case - // the application is slow to accept or stops - // accepting. - // - // NOTE: This won't result in an unbounded - // number of goroutines as we do check before - // entering here that there was at least some - // space available in the backlog. - // Start the protocol goroutine. n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - go e.deliverAccepted(n, true /*withSynCookie*/) + + // Deliver the endpoint to the accept queue. + e.acceptQueue.endpoints.PushBack(n) + e.acceptMu.Unlock() + + e.waiterQueue.Notify(waiter.ReadableEvents) return nil default: @@ -785,14 +745,8 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { ctx := newListenContext(e.stack, e.protocol, e, rcvWnd, v6Only, e.NetProto) defer func() { - // Mark endpoint as closed. This will prevent goroutines running - // handleSynSegment() from attempting to queue new connections - // to the endpoint. e.setEndpointState(StateClose) - // Close any endpoints in SYN-RCVD state. - ctx.closeAllPendingEndpoints() - // Do cleanup if needed. e.completeWorkerLocked() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 5d8e18484..80cd07218 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -30,6 +30,10 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// InitialRTO is the initial retransmission timeout. +// https://github.com/torvalds/linux/blob/7c636d4d20f/include/net/tcp.h#L142 +const InitialRTO = time.Second + // maxSegmentsPerWake is the maximum number of segments to process in the main // protocol goroutine per wake-up. Yielding [after this number of segments are // processed] allows other events to be processed as well (e.g., timeouts, @@ -532,7 +536,7 @@ func (h *handshake) complete() tcpip.Error { defer s.Done() // Initialize the resend timer. - timer, err := newBackoffTimer(h.ep.stack.Clock(), time.Second, MaxRTO, resendWaker.Assert) + timer, err := newBackoffTimer(h.ep.stack.Clock(), InitialRTO, MaxRTO, resendWaker.Assert) if err != nil { return err } @@ -578,6 +582,9 @@ func (h *handshake) complete() tcpip.Error { if (n¬ifyClose)|(n¬ifyAbort) != 0 { return &tcpip.ErrAborted{} } + if n¬ifyShutdown != 0 { + return &tcpip.ErrConnectionReset{} + } if n¬ifyDrain != 0 { for !h.ep.segmentQueue.empty() { s := h.ep.segmentQueue.dequeue() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index d2b8f298f..066ffe051 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -15,7 +15,6 @@ package tcp import ( - "container/list" "encoding/binary" "fmt" "io" @@ -187,6 +186,8 @@ const ( // say TIME_WAIT. notifyTickleWorker notifyError + // notifyShutdown means that a connecting socket was shutdown. + notifyShutdown ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -203,6 +204,8 @@ type SACKInfo struct { } // ReceiveErrors collect segment receive errors within transport layer. +// +// +stateify savable type ReceiveErrors struct { tcpip.ReceiveErrors @@ -232,6 +235,8 @@ type ReceiveErrors struct { } // SendErrors collect segment send errors within the transport layer. +// +// +stateify savable type SendErrors struct { tcpip.SendErrors @@ -255,6 +260,8 @@ type SendErrors struct { } // Stats holds statistics about the endpoint. +// +// +stateify savable type Stats struct { // SegmentsReceived is the number of TCP segments received that // the transport layer successfully parsed. @@ -309,15 +316,6 @@ type rcvQueueInfo struct { rcvQueue segmentList `state:"wait"` } -// +stateify savable -type accepted struct { - // NB: this could be an endpointList, but ilist only permits endpoints to - // belong to one list at a time, and endpoints are already stored in the - // dispatcher's list. - endpoints list.List `state:".([]*endpoint)"` - cap int -} - // endpoint represents a TCP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -333,7 +331,7 @@ type accepted struct { // The following three mutexes can be acquired independent of e.mu but if // acquired with e.mu then e.mu must be acquired first. // -// e.acceptMu -> protects accepted. +// e.acceptMu -> Protects e.acceptQueue. // e.rcvQueueMu -> Protects e.rcvQueue and associated fields. // e.sndQueueMu -> Protects the e.sndQueue and associated fields. // e.lastErrorMu -> Protects the lastError field. @@ -497,10 +495,6 @@ type endpoint struct { // and dropped when it is. segmentQueue segmentQueue `state:"wait"` - // synRcvdCount is the number of connections for this endpoint that are - // in SYN-RCVD state; this is only accessed atomically. - synRcvdCount int32 - // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. userMSS uint16 @@ -573,7 +567,8 @@ type endpoint struct { // accepted is used by a listening endpoint protocol goroutine to // send newly accepted connections to the endpoint so that they can be // read by Accept() calls. - accepted accepted + // +checklocks:acceptMu + acceptQueue acceptQueue // The following are only used from the protocol goroutine, and // therefore don't need locks to protect them. @@ -606,8 +601,7 @@ type endpoint struct { gso stack.GSO - // TODO(b/142022063): Add ability to save and restore per endpoint stats. - stats Stats `state:"nosave"` + stats Stats // tcpLingerTimeout is the maximum amount of a time a socket // a socket stays in TIME_WAIT state before being marked @@ -819,10 +813,9 @@ func newEndpoint(s *stack.Stack, protocol *protocol, netProto tcpip.NetworkProto waiterQueue: waiterQueue, state: uint32(StateInitial), keepalive: keepalive{ - // Linux defaults. - idle: 2 * time.Hour, - interval: 75 * time.Second, - count: 9, + idle: DefaultKeepaliveIdle, + interval: DefaultKeepaliveInterval, + count: DefaultKeepaliveCount, }, uniqueID: s.UniqueID(), txHash: s.Rand().Uint32(), @@ -904,7 +897,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // Check if there's anything in the accepted queue. if (mask & waiter.ReadableEvents) != 0 { e.acceptMu.Lock() - if e.accepted.endpoints.Len() != 0 { + if e.acceptQueue.endpoints.Len() != 0 { result |= waiter.ReadableEvents } e.acceptMu.Unlock() @@ -1087,20 +1080,20 @@ func (e *endpoint) closeNoShutdownLocked() { // handshake but not yet been delivered to the application. func (e *endpoint) closePendingAcceptableConnectionsLocked() { e.acceptMu.Lock() - acceptedCopy := e.accepted - e.accepted = accepted{} - e.acceptMu.Unlock() - - if acceptedCopy == (accepted{}) { - return + // Close any endpoints in SYN-RCVD state. + for n := range e.acceptQueue.pendingEndpoints { + n.notifyProtocolGoroutine(notifyClose) } - - e.acceptCond.Broadcast() - + e.acceptQueue.pendingEndpoints = nil // Reset all connections that are waiting to be accepted. - for n := acceptedCopy.endpoints.Front(); n != nil; n = n.Next() { + for n := e.acceptQueue.endpoints.Front(); n != nil; n = n.Next() { n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset) } + e.acceptQueue.endpoints.Init() + e.acceptMu.Unlock() + + e.acceptCond.Broadcast() + // Wait for reset of all endpoints that are still waiting to be delivered to // the now closed accepted. e.pendingAccepted.Wait() @@ -2060,7 +2053,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { case *tcpip.OriginalDestinationOption: e.LockUser() ipt := e.stack.IPTables() - addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto) + addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto, ProtocolNumber) e.UnlockUser() if err != nil { return err @@ -2380,6 +2373,18 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.LockUser() defer e.UnlockUser() + + if e.EndpointState().connecting() { + // When calling shutdown(2) on a connecting socket, the endpoint must + // enter the error state. But this logic cannot belong to the shutdownLocked + // method because that method is called during a close(2) (and closing a + // connecting socket is not an error). + e.resetConnectionLocked(&tcpip.ErrConnectionReset{}) + e.notifyProtocolGoroutine(notifyShutdown) + e.waiterQueue.Notify(waiter.WritableEvents | waiter.EventHUp | waiter.EventErr) + return nil + } + return e.shutdownLocked(flags) } @@ -2480,22 +2485,23 @@ func (e *endpoint) listen(backlog int) tcpip.Error { if e.EndpointState() == StateListen && !e.closed { e.acceptMu.Lock() defer e.acceptMu.Unlock() - if e.accepted == (accepted{}) { - // listen is called after shutdown. - e.accepted.cap = backlog - e.shutdownFlags = 0 - e.rcvQueueInfo.rcvQueueMu.Lock() - e.rcvQueueInfo.RcvClosed = false - e.rcvQueueInfo.rcvQueueMu.Unlock() - } else { - // Adjust the size of the backlog iff we can fit - // existing pending connections into the new one. - if e.accepted.endpoints.Len() > backlog { - return &tcpip.ErrInvalidEndpointState{} - } - e.accepted.cap = backlog + + // Adjust the size of the backlog iff we can fit + // existing pending connections into the new one. + if e.acceptQueue.endpoints.Len() > backlog { + return &tcpip.ErrInvalidEndpointState{} + } + e.acceptQueue.capacity = backlog + + if e.acceptQueue.pendingEndpoints == nil { + e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{}) } + e.shutdownFlags = 0 + e.rcvQueueInfo.rcvQueueMu.Lock() + e.rcvQueueInfo.RcvClosed = false + e.rcvQueueInfo.rcvQueueMu.Unlock() + // Notify any blocked goroutines that they can attempt to // deliver endpoints again. e.acceptCond.Broadcast() @@ -2530,8 +2536,11 @@ func (e *endpoint) listen(backlog int) tcpip.Error { // may be pre-populated with some previously accepted (but not Accepted) // endpoints. e.acceptMu.Lock() - if e.accepted == (accepted{}) { - e.accepted.cap = backlog + if e.acceptQueue.pendingEndpoints == nil { + e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{}) + } + if e.acceptQueue.capacity == 0 { + e.acceptQueue.capacity = backlog } e.acceptMu.Unlock() @@ -2571,8 +2580,8 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. // Get the new accepted endpoint. var n *endpoint e.acceptMu.Lock() - if element := e.accepted.endpoints.Front(); element != nil { - n = e.accepted.endpoints.Remove(element).(*endpoint) + if element := e.acceptQueue.endpoints.Front(); element != nil { + n = e.acceptQueue.endpoints.Remove(element).(*endpoint) } e.acceptMu.Unlock() if n == nil { @@ -2989,6 +2998,8 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState { } s.Sender.RACKState = e.snd.rc.TCPRACKState + s.Sender.RetransmitTS = e.snd.retransmitTS + s.Sender.SpuriousRecovery = e.snd.spuriousRecovery return s } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index f2e8b3840..94072a115 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -100,7 +100,7 @@ func (e *endpoint) beforeSave() { } // saveEndpoints is invoked by stateify. -func (a *accepted) saveEndpoints() []*endpoint { +func (a *acceptQueue) saveEndpoints() []*endpoint { acceptedEndpoints := make([]*endpoint, a.endpoints.Len()) for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() { acceptedEndpoints[i] = e.Value.(*endpoint) @@ -109,7 +109,7 @@ func (a *accepted) saveEndpoints() []*endpoint { } // loadEndpoints is invoked by stateify. -func (a *accepted) loadEndpoints(acceptedEndpoints []*endpoint) { +func (a *acceptQueue) loadEndpoints(acceptedEndpoints []*endpoint) { for _, ep := range acceptedEndpoints { a.endpoints.PushBack(ep) } @@ -251,7 +251,9 @@ func (e *endpoint) Resume(s *stack.Stack) { go func() { connectedLoading.Wait() bind() - backlog := e.accepted.cap + e.acceptMu.Lock() + backlog := e.acceptQueue.capacity + e.acceptMu.Unlock() if err := e.Listen(backlog); err != nil { panic("endpoint listening failed: " + err.String()) } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index e4410ad93..f122ea009 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -66,6 +66,18 @@ const ( // DefaultSynRetries is the default value for the number of SYN retransmits // before a connect is aborted. DefaultSynRetries = 6 + + // DefaultKeepaliveIdle is the idle time for a connection before keep-alive + // probes are sent. + DefaultKeepaliveIdle = 2 * time.Hour + + // DefaultKeepaliveInterval is the time between two successive keep-alive + // probes. + DefaultKeepaliveInterval = 75 * time.Second + + // DefaultKeepaliveCount is the number of keep-alive probes that are sent + // before declaring the connection dead. + DefaultKeepaliveCount = 9 ) const ( diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go index 8a026ec46..e47a07030 100644 --- a/pkg/tcpip/transport/tcp/rcv_test.go +++ b/pkg/tcpip/transport/tcp/rcv_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package rcv_test +package tcp_test import ( "testing" diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go index 2e6ea06f5..2d5fdda19 100644 --- a/pkg/tcpip/transport/tcp/segment_test.go +++ b/pkg/tcpip/transport/tcp/segment_test.go @@ -34,7 +34,7 @@ func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeW DataSize: seg.data.Size(), SegMemSize: seg.segMemSize(), } - if diff := cmp.Diff(got, want); diff != "" { + if diff := cmp.Diff(want, got); diff != "" { t.Errorf("%s differs (-want +got):\n%s", name, diff) } } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 2fabf1594..4377f07a0 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -144,6 +144,15 @@ type sender struct { // probeTimer and probeWaker are used to schedule PTO for RACK TLP algorithm. probeTimer timer `state:"nosave"` probeWaker sleep.Waker `state:"nosave"` + + // spuriousRecovery indicates whether the sender entered recovery + // spuriously as described in RFC3522 Section 3.2. + spuriousRecovery bool + + // retransmitTS is the timestamp at which the sender sends retransmitted + // segment after entering an RTO for the first time as described in + // RFC3522 Section 3.2. + retransmitTS uint32 } // rtt is a synchronization wrapper used to appease stateify. See the comment @@ -425,6 +434,13 @@ func (s *sender) retransmitTimerExpired() bool { return true } + // Initialize the variables used to detect spurious recovery after + // entering RTO. + // + // See: https://www.rfc-editor.org/rfc/rfc3522.html#section-3.2 Step 1. + s.spuriousRecovery = false + s.retransmitTS = 0 + // TODO(b/147297758): Band-aid fix, retransmitTimer can fire in some edge cases // when writeList is empty. Remove this once we have a proper fix for this // issue. @@ -495,6 +511,10 @@ func (s *sender) retransmitTimerExpired() bool { s.leaveRecovery() } + // Record retransmitTS if the sender is not in recovery as per: + // https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2 + s.recordRetransmitTS() + s.state = tcpip.RTORecovery s.cc.HandleRTOExpired() @@ -958,6 +978,13 @@ func (s *sender) sendData() { } func (s *sender) enterRecovery() { + // Initialize the variables used to detect spurious recovery after + // entering recovery. + // + // See: https://www.rfc-editor.org/rfc/rfc3522.html#section-3.2 Step 1. + s.spuriousRecovery = false + s.retransmitTS = 0 + s.FastRecovery.Active = true // Save state to reflect we're now in fast recovery. // @@ -972,6 +999,11 @@ func (s *sender) enterRecovery() { s.FastRecovery.MaxCwnd = s.SndCwnd + s.Outstanding s.FastRecovery.HighRxt = s.SndUna s.FastRecovery.RescueRxt = s.SndUna + + // Record retransmitTS if the sender is not in recovery as per: + // https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2 + s.recordRetransmitTS() + if s.ep.SACKPermitted { s.state = tcpip.SACKRecovery s.ep.stack.Stats().TCP.SACKRecovery.Increment() @@ -1147,13 +1179,15 @@ func (s *sender) isDupAck(seg *segment) bool { // Iterate the writeList and update RACK for each segment which is newly acked // either cumulatively or selectively. Loop through the segments which are // sacked, and update the RACK related variables and check for reordering. +// Returns true when the DSACK block has been detected in the received ACK. // // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // steps 2 and 3. -func (s *sender) walkSACK(rcvdSeg *segment) { +func (s *sender) walkSACK(rcvdSeg *segment) bool { s.rc.setDSACKSeen(false) // Look for DSACK block. + hasDSACK := false idx := 0 n := len(rcvdSeg.parsedOptions.SACKBlocks) if checkDSACK(rcvdSeg) { @@ -1167,10 +1201,11 @@ func (s *sender) walkSACK(rcvdSeg *segment) { s.rc.setDSACKSeen(true) idx = 1 n-- + hasDSACK = true } if n == 0 { - return + return hasDSACK } // Sort the SACK blocks. The first block is the most recent unacked @@ -1193,6 +1228,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) { seg = seg.Next() } } + return hasDSACK } // checkDSACK checks if a DSACK is reported. @@ -1239,6 +1275,85 @@ func checkDSACK(rcvdSeg *segment) bool { return false } +func (s *sender) recordRetransmitTS() { + // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 + // + // The Eifel detection algorithm is used, only upon initiation of loss + // recovery, i.e., when either the timeout-based retransmit or the fast + // retransmit is sent. The Eifel detection algorithm MUST NOT be + // reinitiated after loss recovery has already started. In particular, + // it must not be reinitiated upon subsequent timeouts for the same + // segment, and not upon retransmitting segments other than the oldest + // outstanding segment, e.g., during selective loss recovery. + if s.inRecovery() { + return + } + + // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2 + // + // Set a "RetransmitTS" variable to the value of the Timestamp Value + // field of the Timestamps option included in the retransmit sent when + // loss recovery is initiated. A TCP sender must ensure that + // RetransmitTS does not get overwritten as loss recovery progresses, + // e.g., in case of a second timeout and subsequent second retransmit of + // the same octet. + s.retransmitTS = s.ep.tsValNow() +} + +func (s *sender) detectSpuriousRecovery(hasDSACK bool, tsEchoReply uint32) { + // Return if the sender has already detected spurious recovery. + if s.spuriousRecovery { + return + } + + // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 4 + // + // If the value of the Timestamp Echo Reply field of the acceptable ACK's + // Timestamps option is smaller than the value of RetransmitTS, then + // proceed to next step, else return. + if tsEchoReply >= s.retransmitTS { + return + } + + // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 5 + // + // If the acceptable ACK carries a DSACK option [RFC2883], then return. + if hasDSACK { + return + } + + // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 5 + // + // If during the lifetime of the TCP connection the TCP sender has + // previously received an ACK with a DSACK option, or the acceptable ACK + // does not acknowledge all outstanding data, then proceed to next step, + // else return. + numDSACK := s.ep.stack.Stats().TCP.SegmentsAckedWithDSACK.Value() + if numDSACK == 0 && s.SndUna == s.SndNxt { + return + } + + // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 6 + // + // If the loss recovery has been initiated with a timeout-based + // retransmit, then set + // SpuriousRecovery <- SPUR_TO (equal 1), + // else set + // SpuriousRecovery <- dupacks+1 + // Set the spurious recovery variable to true as we do not differentiate + // between fast, SACK or RTO recovery. + s.spuriousRecovery = true + s.ep.stack.Stats().TCP.SpuriousRecovery.Increment() +} + +// Check if the sender is in RTORecovery, FastRecovery or SACKRecovery state. +func (s *sender) inRecovery() bool { + if s.state == tcpip.RTORecovery || s.state == tcpip.FastRecovery || s.state == tcpip.SACKRecovery { + return true + } + return false +} + // handleRcvdSegment is called when a segment is received; it is responsible for // updating the send-related state. func (s *sender) handleRcvdSegment(rcvdSeg *segment) { @@ -1254,6 +1369,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // Insert SACKBlock information into our scoreboard. + hasDSACK := false if s.ep.SACKPermitted { for _, sb := range rcvdSeg.parsedOptions.SACKBlocks { // Only insert the SACK block if the following holds @@ -1288,7 +1404,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // RACK.fack, then the corresponding packet has been // reordered and RACK.reord is set to TRUE. if s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { - s.walkSACK(rcvdSeg) + hasDSACK = s.walkSACK(rcvdSeg) } s.SetPipe() } @@ -1418,6 +1534,11 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Clear SACK information for all acked data. s.ep.scoreboard.Delete(s.SndUna) + // Detect if the sender entered recovery spuriously. + if s.inRecovery() { + s.detectSpuriousRecovery(hasDSACK, rcvdSeg.parsedOptions.TSEcr) + } + // If we are not in fast recovery then update the congestion // window based on the number of acknowledged packets. if !s.FastRecovery.Active { diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index c35db7c95..0d36d0dd0 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -1059,16 +1059,17 @@ func TestRACKWithWindowFull(t *testing.T) { for i := 0; i < numPkts; i++ { c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) bytesRead += maxPayload - if i == 0 { - // Send ACK for the first packet to establish RTT. - c.SendAck(seq, maxPayload) - } } - // SACK for #10 packet. - start := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload)) + // Expect retransmission of last packet due to TLP. + c.ReceiveAndCheckPacketWithOptions(data, (numPkts-1)*maxPayload, maxPayload, tsOptionSize) + + // SACK for first and last packet. + start := c.IRS.Add(seqnum.Size(maxPayload)) end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{start, end}}) + dsackStart := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload)) + dsackEnd := dsackStart.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) var info tcpip.TCPInfoOption if err := c.EP.GetSockOpt(&info); err != nil { diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 6255355bb..896249d2d 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -23,6 +23,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -702,3 +703,257 @@ func TestRecoveryEntry(t *testing.T) { t.Error(err) } } + +func verifySpuriousRecoveryMetric(t *testing.T, c *context.Context, numSpuriousRecovery uint64) { + t.Helper() + + metricPollFn := func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + {tcpStats.SpuriousRecovery, "stats.TCP.SpuriousRecovery", numSpuriousRecovery}, + } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + } + } + return nil + } + + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } +} + +func checkReceivedPacket(t *testing.T, c *context.Context, tcpHdr header.TCP, bytesRead uint32, b, data []byte) { + payloadLen := uint32(len(tcpHdr.Payload())) + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1+bytesRead), + checker.TCPAckNum(context.TestInitialSequenceNumber+1), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) + pdata := data[bytesRead : bytesRead+payloadLen] + if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) { + t.Fatalf("got data = %v, want = %v", p, pdata) + } +} + +func buildTSOptionFromHeader(tcpHdr header.TCP) []byte { + parsedOpts := tcpHdr.ParsedOptions() + tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} + header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) + return tsOpt[:] +} + +func TestDetectSpuriousRecoveryWithRTO(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan struct{}) + c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) { + if s.Sender.RetransmitTS == 0 { + t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0") + } + if !s.Sender.SpuriousRecovery { + t.Fatalf("Spurious recovery was not detected") + } + close(probeDone) + }) + + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + numPackets := 5 + data := make([]byte, numPackets*maxPayload) + for i := range data { + data[i] = byte(i) + } + // Write the data. + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + var options []byte + var bytesRead uint32 + for i := 0; i < numPackets; i++ { + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data) + + // Get options only for the first packet. This will be sent with + // the ACK to indicate the acknowledgement is for the original + // packet. + if i == 0 && c.TimeStampEnabled { + options = buildTSOptionFromHeader(tcpHdr) + } + bytesRead += uint32(len(tcpHdr.Payload())) + } + + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + // Expect #5 segment with TLP. + c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize) + + // Expect #1 segment because of RTO. + c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize) + + info := tcpip.TCPInfoOption{} + if err := c.EP.GetSockOpt(&info); err != nil { + t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err) + } + + if info.CcState != tcpip.RTORecovery { + t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.RTORecovery) + } + + // Acknowledge the data. + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seq, + AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)), + RcvWnd: rcvWnd, + TCPOpts: options, + }) + + // Wait for the probe function to finish processing the + // ACK before the test completes. + <-probeDone + + verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */) +} + +func TestSACKDetectSpuriousRecoveryWithDupACK(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + numAck := 0 + probeDone := make(chan struct{}) + c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) { + if numAck < 3 { + numAck++ + return + } + + if s.Sender.RetransmitTS == 0 { + t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0") + } + if !s.Sender.SpuriousRecovery { + t.Fatalf("Spurious recovery was not detected") + } + close(probeDone) + }) + + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + numPackets := 5 + data := make([]byte, numPackets*maxPayload) + for i := range data { + data[i] = byte(i) + } + // Write the data. + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + var options []byte + var bytesRead uint32 + for i := 0; i < numPackets; i++ { + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data) + + // Get options only for the first packet. This will be sent with + // the ACK to indicate the acknowledgement is for the original + // packet. + if i == 0 && c.TimeStampEnabled { + options = buildTSOptionFromHeader(tcpHdr) + } + bytesRead += uint32(len(tcpHdr.Payload())) + } + + // Receive the retransmitted packet after TLP. + c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize) + + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + // Send ACK for #3 and #4 segments to avoid entering TLP. + start := c.IRS.Add(3*maxPayload + 1) + end := start.Add(2 * maxPayload) + c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) + + c.SendAck(seq, 0 /* bytesReceived */) + c.SendAck(seq, 0 /* bytesReceived */) + + // Receive the retransmitted packet after three duplicate ACKs. + c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize) + + info := tcpip.TCPInfoOption{} + if err := c.EP.GetSockOpt(&info); err != nil { + t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err) + } + + if info.CcState != tcpip.SACKRecovery { + t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.SACKRecovery) + } + + // Acknowledge the data. + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seq, + AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)), + RcvWnd: rcvWnd, + TCPOpts: options, + }) + + // Wait for the probe function to finish processing the + // ACK before the test completes. + <-probeDone + + verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */) +} + +func TestNoSpuriousRecoveryWithDSACK(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + numPackets := 5 + data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) + + // Receive the retransmitted packet after TLP. + c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize) + + // Send ACK for #3 and #4 segments to avoid entering TLP. + start := c.IRS.Add(3*maxPayload + 1) + end := start.Add(2 * maxPayload) + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) + + c.SendAck(seq, 0 /* bytesReceived */) + c.SendAck(seq, 0 /* bytesReceived */) + + // Receive the retransmitted packet after three duplicate ACKs. + c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize) + + // Acknowledge the data with DSACK for #1 segment. + start = c.IRS.Add(maxPayload + 1) + end = start.Add(2 * maxPayload) + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAckWithSACK(seq, 6*maxPayload, []header.SACKBlock{{start, end}}) + + verifySpuriousRecoveryMetric(t, c, 0 /* numSpuriousRecovery */) +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index bc8708a5b..6f1ee3816 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1382,8 +1382,12 @@ func TestListenerReadinessOnEvent(t *testing.T) { if err := s.CreateNIC(id, ep); err != nil { t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err) } - if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil { - t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.Address(context.StackAddr).WithPrefix(), + } + if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{ {Destination: header.IPv4EmptySubnet, NIC: id}, @@ -1652,6 +1656,71 @@ func TestConnectBindToDevice(t *testing.T) { } } +func TestShutdownConnectingSocket(t *testing.T) { + for _, test := range []struct { + name string + shutdownMode tcpip.ShutdownFlags + }{ + {"ShutdownRead", tcpip.ShutdownRead}, + {"ShutdownWrite", tcpip.ShutdownWrite}, + {"ShutdownReadWrite", tcpip.ShutdownRead | tcpip.ShutdownWrite}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create an endpoint, don't handshake because we want to interfere with + // the handshake process. + c.Create(-1) + + waitEntry, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventHUp) + defer c.WQ.EventUnregister(&waitEntry) + + // Start connection attempt. + addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, c.EP.Connect(addr)); d != "" { + t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) + } + + // Check the SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + + if err := c.EP.Shutdown(test.shutdownMode); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } + + // The endpoint internal state is updated immediately. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + + select { + case <-ch: + default: + t.Fatal("endpoint was not notified") + } + + ept := endpointTester{c.EP} + ept.CheckReadError(t, &tcpip.ErrConnectionReset{}) + + // If the endpoint is not properly shutdown, it'll re-attempt to connect + // by sending another ACK packet. + c.CheckNoPacketTimeout("got an unexpected packet", tcp.InitialRTO+(500*time.Millisecond)) + }) + } +} + func TestSynSent(t *testing.T) { for _, test := range []struct { name string @@ -1675,7 +1744,7 @@ func TestSynSent(t *testing.T) { addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} err := c.EP.Connect(addr) - if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" { + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) } @@ -1991,7 +2060,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { ) // Cause a FIN to be generated. - c.EP.Shutdown(tcpip.ShutdownWrite) + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } // Make sure we get the FIN but DON't ACK IT. checker.IPv4(t, c.GetPacket(), @@ -2007,7 +2078,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { // Cause a RST to be generated by closing the read end now since we have // unread data. - c.EP.Shutdown(tcpip.ShutdownRead) + if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } // Make sure we get the RST checker.IPv4(t, c.GetPacket(), @@ -2145,12 +2218,15 @@ func TestSmallReceiveBufferReadiness(t *testing.T) { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %s", nicOpts, err) } - addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x7f\x00\x00\x01"), - PrefixLen: 8, + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x7f\x00\x00\x01"), + PrefixLen: 8, + }, } - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddressWithPrefix(_, _, %s) failed: %s", addr, err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) failed: %s", nicID, protocolAddr, err) } { @@ -4954,13 +5030,17 @@ func makeStack() (*stack.Stack, tcpip.Error) { } for _, ct := range []struct { - number tcpip.NetworkProtocolNumber - address tcpip.Address + number tcpip.NetworkProtocolNumber + addrWithPrefix tcpip.AddressWithPrefix }{ - {ipv4.ProtocolNumber, context.StackAddr}, - {ipv6.ProtocolNumber, context.StackV6Addr}, + {ipv4.ProtocolNumber, context.StackAddrWithPrefix}, + {ipv6.ProtocolNumber, context.StackV6AddrWithPrefix}, } { - if err := s.AddAddress(1, ct.number, ct.address); err != nil { + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ct.number, + AddressWithPrefix: ct.addrWithPrefix, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { return nil, err } } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 6e55a7a32..88bb99354 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -243,8 +243,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context { Protocol: ipv4.ProtocolNumber, AddressWithPrefix: StackAddrWithPrefix, } - if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) + if err := s.AddProtocolAddress(1, v4ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v4ProtocolAddr, err) } routeTable = append(routeTable, tcpip.Route{ Destination: header.IPv4EmptySubnet, @@ -257,8 +257,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context { Protocol: ipv6.ProtocolNumber, AddressWithPrefix: StackV6AddrWithPrefix, } - if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) + if err := s.AddProtocolAddress(1, v6ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v6ProtocolAddr, err) } routeTable = append(routeTable, tcpip.Route{ Destination: header.IPv6EmptySubnet, diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 5cc7a2886..d2c0963b0 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -63,5 +63,6 @@ go_test( "//pkg/tcpip/transport/icmp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", + "@org_golang_x_time//rate:go_default_library", ], ) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 4255457f9..077a2325a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -60,9 +60,8 @@ type endpoint struct { waiterQueue *waiter.Queue uniqueID uint64 net network.Endpoint - // TODO(b/142022063): Add ability to save and restore per endpoint stats. - stats tcpip.TransportEndpointStats `state:"nosave"` - ops tcpip.SocketOptions + stats tcpip.TransportEndpointStats + ops tcpip.SocketOptions // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -234,7 +233,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // Control Messages cm := tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: p.receivedAt.UnixNano(), + Timestamp: p.receivedAt, } switch p.netProto { @@ -243,19 +242,29 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult cm.HasTOS = true cm.TOS = p.tos } + + if e.ops.GetReceivePacketInfo() { + cm.HasIPPacketInfo = true + cm.PacketInfo = p.packetInfo + } case header.IPv6ProtocolNumber: if e.ops.GetReceiveTClass() { cm.HasTClass = true // Although TClass is an 8-bit value it's read in the CMsg as a uint32. cm.TClass = uint32(p.tos) } + + if e.ops.GetIPv6ReceivePacketInfo() { + cm.HasIPv6PacketInfo = true + cm.IPv6PacketInfo = tcpip.IPv6PacketInfo{ + NIC: p.packetInfo.NIC, + Addr: p.packetInfo.DestinationAddr, + } + } default: panic(fmt.Sprintf("unrecognized network protocol = %d", p.netProto)) } - if e.ops.GetReceivePacketInfo() { - cm.HasIPPacketInfo = true - cm.PacketInfo = p.packetInfo - } + if e.ops.GetReceiveOriginalDstAddress() { cm.HasOriginalDstAddress = true cm.OriginalDstAddress = p.destinationAddress @@ -283,7 +292,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. -// +checklocks:e.mu +// +checklocksread:e.mu func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { switch e.net.State() { case transport.DatagramEndpointStateInitial: diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 554ce1de4..b3199489c 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -313,6 +314,9 @@ func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal boo Clock: &faketime.NullClock{}, } s := stack.New(options) + // Disable ICMP rate limiter because we're using Null clock, which never advances time and thus + // never allows ICMP messages. + s.SetICMPLimit(rate.Inf) ep := channel.New(256, mtu, "") wep := stack.LinkEndpoint(ep) @@ -323,12 +327,20 @@ func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal boo t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.Address(stackAddr).WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress((%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, stackV6Addr, err) + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.Address(stackV6Addr).WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err) } s.SetRouteTable([]tcpip.Route{ @@ -1357,64 +1369,70 @@ func TestReadIncrementsPacketsReceived(t *testing.T) { func TestReadIPPacketInfo(t *testing.T) { tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - flow testFlow - expectedLocalAddr tcpip.Address - expectedDestAddr tcpip.Address + name string + proto tcpip.NetworkProtocolNumber + flow testFlow + checker func(tcpip.NICID) checker.ControlMessagesChecker }{ { - name: "IPv4 unicast", - proto: header.IPv4ProtocolNumber, - flow: unicastV4, - expectedLocalAddr: stackAddr, - expectedDestAddr: stackAddr, + name: "IPv4 unicast", + proto: header.IPv4ProtocolNumber, + flow: unicastV4, + checker: func(id tcpip.NICID) checker.ControlMessagesChecker { + return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ + NIC: id, + LocalAddr: stackAddr, + DestinationAddr: stackAddr, + }) + }, }, { name: "IPv4 multicast", proto: header.IPv4ProtocolNumber, flow: multicastV4, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedLocalAddr: multicastAddr, - expectedDestAddr: multicastAddr, + checker: func(id tcpip.NICID) checker.ControlMessagesChecker { + return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ + NIC: id, + // TODO(gvisor.dev/issue/3556): Check for a unicast address. + LocalAddr: multicastAddr, + DestinationAddr: multicastAddr, + }) + }, }, { name: "IPv4 broadcast", proto: header.IPv4ProtocolNumber, flow: broadcast, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedLocalAddr: broadcastAddr, - expectedDestAddr: broadcastAddr, + checker: func(id tcpip.NICID) checker.ControlMessagesChecker { + return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ + NIC: id, + // TODO(gvisor.dev/issue/3556): Check for a unicast address. + LocalAddr: broadcastAddr, + DestinationAddr: broadcastAddr, + }) + }, }, { - name: "IPv6 unicast", - proto: header.IPv6ProtocolNumber, - flow: unicastV6, - expectedLocalAddr: stackV6Addr, - expectedDestAddr: stackV6Addr, + name: "IPv6 unicast", + proto: header.IPv6ProtocolNumber, + flow: unicastV6, + checker: func(id tcpip.NICID) checker.ControlMessagesChecker { + return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{ + NIC: id, + Addr: stackV6Addr, + }) + }, }, { name: "IPv6 multicast", proto: header.IPv6ProtocolNumber, flow: multicastV6, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedLocalAddr: multicastV6Addr, - expectedDestAddr: multicastV6Addr, + checker: func(id tcpip.NICID) checker.ControlMessagesChecker { + return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{ + NIC: id, + Addr: multicastV6Addr, + }) + }, }, } @@ -1437,13 +1455,16 @@ func TestReadIPPacketInfo(t *testing.T) { } } - c.ep.SocketOptions().SetReceivePacketInfo(true) + switch f := test.flow.netProto(); f { + case header.IPv4ProtocolNumber: + c.ep.SocketOptions().SetReceivePacketInfo(true) + case header.IPv6ProtocolNumber: + c.ep.SocketOptions().SetIPv6ReceivePacketInfo(true) + default: + t.Fatalf("unhandled protocol number = %d", f) + } - testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ - NIC: 1, - LocalAddr: test.expectedLocalAddr, - DestinationAddr: test.expectedDestAddr, - })) + testRead(c, test.flow, test.checker(c.nicID)) if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) @@ -2504,8 +2525,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { if err := s.CreateNIC(nicID1, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err) } s.SetRouteTable(test.routes) diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD index 234125c38..8902be2d3 100644 --- a/pkg/unet/BUILD +++ b/pkg/unet/BUILD @@ -10,6 +10,7 @@ go_library( ], visibility = ["//visibility:public"], deps = [ + "//pkg/eventfd", "//pkg/sync", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/unet/unet.go b/pkg/unet/unet.go index 40fa72925..0dc0c37bd 100644 --- a/pkg/unet/unet.go +++ b/pkg/unet/unet.go @@ -23,6 +23,7 @@ import ( "sync/atomic" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/eventfd" "gvisor.dev/gvisor/pkg/sync" ) @@ -55,15 +56,6 @@ func socket(packet bool) (int, error) { return fd, nil } -// eventFD returns a new event FD with initial value 0. -func eventFD() (int, error) { - f, _, e := unix.Syscall(unix.SYS_EVENTFD2, 0, 0, 0) - if e != 0 { - return -1, e - } - return int(f), nil -} - // Socket is a connected unix domain socket. type Socket struct { // gate protects use of fd. @@ -78,7 +70,7 @@ type Socket struct { // efd is an event FD that is signaled when the socket is closing. // // efd is immutable and remains valid until Close/Release. - efd int + efd eventfd.Eventfd // race is an atomic variable used to avoid triggering the race // detector. See comment in SocketPair below. @@ -95,7 +87,7 @@ func NewSocket(fd int) (*Socket, error) { return nil, err } - efd, err := eventFD() + efd, err := eventfd.Create() if err != nil { return nil, err } @@ -110,16 +102,14 @@ func NewSocket(fd int) (*Socket, error) { // closing the event FD. func (s *Socket) finish() error { // Signal any blocked or future polls. - // - // N.B. eventfd writes must be 8 bytes. - if _, err := unix.Write(s.efd, []byte{1, 0, 0, 0, 0, 0, 0, 0}); err != nil { + if err := s.efd.Notify(); err != nil { return err } // Close the gate, blocking until all FD users leave. s.gate.Close() - return unix.Close(s.efd) + return s.efd.Close() } // Close closes the socket. diff --git a/pkg/unet/unet_unsafe.go b/pkg/unet/unet_unsafe.go index f0bf93ddd..ea281fec3 100644 --- a/pkg/unet/unet_unsafe.go +++ b/pkg/unet/unet_unsafe.go @@ -43,7 +43,7 @@ func (s *Socket) wait(write bool) error { }, { // The eventfd, signaled when we are closing. - Fd: int32(s.efd), + Fd: int32(s.efd.FD()), Events: unix.POLLIN, }, } |