diff options
Diffstat (limited to 'pkg/flipcall')
-rw-r--r-- | pkg/flipcall/BUILD | 33 | ||||
-rw-r--r-- | pkg/flipcall/ctrl_futex.go | 176 | ||||
-rw-r--r-- | pkg/flipcall/flipcall.go | 257 | ||||
-rw-r--r-- | pkg/flipcall/flipcall_example_test.go | 113 | ||||
-rw-r--r-- | pkg/flipcall/flipcall_test.go | 405 | ||||
-rw-r--r-- | pkg/flipcall/flipcall_unsafe.go | 87 | ||||
-rw-r--r-- | pkg/flipcall/futex_linux.go | 118 | ||||
-rw-r--r-- | pkg/flipcall/io.go | 113 | ||||
-rw-r--r-- | pkg/flipcall/packet_window_allocator.go | 166 |
9 files changed, 1468 insertions, 0 deletions
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD new file mode 100644 index 000000000..9c5ad500b --- /dev/null +++ b/pkg/flipcall/BUILD @@ -0,0 +1,33 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +licenses(["notice"]) + +go_library( + name = "flipcall", + srcs = [ + "ctrl_futex.go", + "flipcall.go", + "flipcall_unsafe.go", + "futex_linux.go", + "io.go", + "packet_window_allocator.go", + ], + visibility = ["//visibility:public"], + deps = [ + "//pkg/abi/linux", + "//pkg/log", + "//pkg/memutil", + "//pkg/sync", + ], +) + +go_test( + name = "flipcall_test", + size = "small", + srcs = [ + "flipcall_example_test.go", + "flipcall_test.go", + ], + library = ":flipcall", + deps = ["//pkg/sync"], +) diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go new file mode 100644 index 000000000..e7c3a3a0b --- /dev/null +++ b/pkg/flipcall/ctrl_futex.go @@ -0,0 +1,176 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package flipcall + +import ( + "encoding/json" + "fmt" + "math" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/log" +) + +type endpointControlImpl struct { + state int32 +} + +// Bits in endpointControlImpl.state. +const ( + epsBlocked = 1 << iota + epsShutdown +) + +func (ep *Endpoint) ctrlInit(opts ...EndpointOption) error { + if len(opts) != 0 { + return fmt.Errorf("unknown EndpointOption: %T", opts[0]) + } + return nil +} + +type ctrlHandshakeRequest struct{} + +type ctrlHandshakeResponse struct{} + +func (ep *Endpoint) ctrlConnect() error { + if err := ep.enterFutexWait(); err != nil { + return err + } + _, err := ep.futexConnect(&ctrlHandshakeRequest{}) + ep.exitFutexWait() + return err +} + +func (ep *Endpoint) ctrlWaitFirst() error { + if err := ep.enterFutexWait(); err != nil { + return err + } + defer ep.exitFutexWait() + + // Wait for the handshake request. + if err := ep.futexSwitchFromPeer(); err != nil { + return err + } + + // Read the handshake request. + reqLen := atomic.LoadUint32(ep.dataLen()) + if reqLen > ep.dataCap { + return fmt.Errorf("invalid handshake request length %d (maximum %d)", reqLen, ep.dataCap) + } + var req ctrlHandshakeRequest + if err := json.NewDecoder(ep.NewReader(reqLen)).Decode(&req); err != nil { + return fmt.Errorf("error reading handshake request: %v", err) + } + + // Write the handshake response. + w := ep.NewWriter() + if err := json.NewEncoder(w).Encode(ctrlHandshakeResponse{}); err != nil { + return fmt.Errorf("error writing handshake response: %v", err) + } + *ep.dataLen() = w.Len() + + // Return control to the client. + raceBecomeInactive() + if err := ep.futexSwitchToPeer(); err != nil { + return err + } + + // Wait for the first non-handshake message. + return ep.futexSwitchFromPeer() +} + +func (ep *Endpoint) ctrlRoundTrip() error { + if err := ep.futexSwitchToPeer(); err != nil { + return err + } + if err := ep.enterFutexWait(); err != nil { + return err + } + err := ep.futexSwitchFromPeer() + ep.exitFutexWait() + return err +} + +func (ep *Endpoint) ctrlWakeLast() error { + return ep.futexSwitchToPeer() +} + +func (ep *Endpoint) enterFutexWait() error { + switch eps := atomic.AddInt32(&ep.ctrl.state, epsBlocked); eps { + case epsBlocked: + return nil + case epsBlocked | epsShutdown: + atomic.AddInt32(&ep.ctrl.state, -epsBlocked) + return ShutdownError{} + default: + // Most likely due to ep.enterFutexWait() being called concurrently + // from multiple goroutines. + panic(fmt.Sprintf("invalid flipcall.Endpoint.ctrl.state before flipcall.Endpoint.enterFutexWait(): %v", eps-epsBlocked)) + } +} + +func (ep *Endpoint) exitFutexWait() { + switch eps := atomic.AddInt32(&ep.ctrl.state, -epsBlocked); eps { + case 0: + return + case epsShutdown: + // ep.ctrlShutdown() was called while we were blocked, so we are + // repsonsible for indicating connection shutdown. + ep.shutdownConn() + default: + panic(fmt.Sprintf("invalid flipcall.Endpoint.ctrl.state after flipcall.Endpoint.exitFutexWait(): %v", eps+epsBlocked)) + } +} + +func (ep *Endpoint) ctrlShutdown() { + // Set epsShutdown to ensure that future calls to ep.enterFutexWait() fail. + if atomic.AddInt32(&ep.ctrl.state, epsShutdown)&epsBlocked != 0 { + // Wake the blocked thread. This must loop because it's possible that + // FUTEX_WAKE occurs after the waiter sets epsBlocked, but before it + // blocks in FUTEX_WAIT. + for { + // Wake MaxInt32 threads to prevent a broken or malicious peer from + // swallowing our wakeup by FUTEX_WAITing from multiple threads. + if err := ep.futexWakeConnState(math.MaxInt32); err != nil { + log.Warningf("failed to FUTEX_WAKE Endpoints: %v", err) + break + } + yieldThread() + if atomic.LoadInt32(&ep.ctrl.state)&epsBlocked == 0 { + break + } + } + } else { + // There is no blocked thread, so we are responsible for indicating + // connection shutdown. + ep.shutdownConn() + } +} + +func (ep *Endpoint) shutdownConn() { + switch cs := atomic.SwapUint32(ep.connState(), csShutdown); cs { + case ep.activeState: + if err := ep.futexWakeConnState(1); err != nil { + log.Warningf("failed to FUTEX_WAKE peer Endpoint for shutdown: %v", err) + } + case ep.inactiveState: + // The peer is currently active and will detect shutdown when it tries + // to update the connection state. + case csShutdown: + // The peer also called Endpoint.Shutdown(). + default: + log.Warningf("unexpected connection state before Endpoint.shutdownConn(): %v", cs) + } +} diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go new file mode 100644 index 000000000..3cdb576e1 --- /dev/null +++ b/pkg/flipcall/flipcall.go @@ -0,0 +1,257 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package flipcall implements a protocol providing Fast Local Interprocess +// Procedure Calls between mutually-distrusting processes. +package flipcall + +import ( + "fmt" + "math" + "sync/atomic" + "syscall" +) + +// An Endpoint provides the ability to synchronously transfer data and control +// to a connected peer Endpoint, which may be in another process. +// +// Since the Endpoint control transfer model is synchronous, at any given time +// one Endpoint "has control" (designated the active Endpoint), and the other +// is "waiting for control" (designated the inactive Endpoint). Users of the +// flipcall package designate one Endpoint as the client, which is initially +// active, and the other as the server, which is initially inactive. See +// flipcall_example_test.go for usage. +type Endpoint struct { + // packet is a pointer to the beginning of the packet window. (Since this + // is a raw OS memory mapping and not a Go object, it does not need to be + // represented as an unsafe.Pointer.) packet is immutable. + packet uintptr + + // dataCap is the size of the datagram part of the packet window in bytes. + // dataCap is immutable. + dataCap uint32 + + // activeState is csClientActive if this is a client Endpoint and + // csServerActive if this is a server Endpoint. + activeState uint32 + + // inactiveState is csServerActive if this is a client Endpoint and + // csClientActive if this is a server Endpoint. + inactiveState uint32 + + // shutdown is non-zero if Endpoint.Shutdown() has been called, or if the + // Endpoint has acknowledged shutdown initiated by the peer. shutdown is + // accessed using atomic memory operations. + shutdown uint32 + + ctrl endpointControlImpl +} + +// EndpointSide indicates which side of a connection an Endpoint belongs to. +type EndpointSide int + +const ( + // ClientSide indicates that an Endpoint is a client (initially-active; + // first method call should be Connect). + ClientSide EndpointSide = iota + + // ServerSide indicates that an Endpoint is a server (initially-inactive; + // first method call should be RecvFirst.) + ServerSide +) + +// Init must be called on zero-value Endpoints before first use. If it +// succeeds, ep.Destroy() must be called once the Endpoint is no longer in use. +// +// pwd represents the packet window used to exchange data with the peer +// Endpoint. FD may differ between Endpoints if they are in different +// processes, but must represent the same file. The packet window must +// initially be filled with zero bytes. +func (ep *Endpoint) Init(side EndpointSide, pwd PacketWindowDescriptor, opts ...EndpointOption) error { + switch side { + case ClientSide: + ep.activeState = csClientActive + ep.inactiveState = csServerActive + case ServerSide: + ep.activeState = csServerActive + ep.inactiveState = csClientActive + default: + return fmt.Errorf("invalid EndpointSide: %v", side) + } + if pwd.Length < pageSize { + return fmt.Errorf("packet window size (%d) less than minimum (%d)", pwd.Length, pageSize) + } + if pwd.Length > math.MaxUint32 { + return fmt.Errorf("packet window size (%d) exceeds maximum (%d)", pwd.Length, math.MaxUint32) + } + m, _, e := syscall.RawSyscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset)) + if e != 0 { + return fmt.Errorf("failed to mmap packet window: %v", e) + } + ep.packet = m + ep.dataCap = uint32(pwd.Length) - uint32(PacketHeaderBytes) + if err := ep.ctrlInit(opts...); err != nil { + ep.unmapPacket() + return err + } + return nil +} + +// NewEndpoint is a convenience function that returns an initialized Endpoint +// allocated on the heap. +func NewEndpoint(side EndpointSide, pwd PacketWindowDescriptor, opts ...EndpointOption) (*Endpoint, error) { + var ep Endpoint + if err := ep.Init(side, pwd, opts...); err != nil { + return nil, err + } + return &ep, nil +} + +// An EndpointOption configures an Endpoint. +type EndpointOption interface { + isEndpointOption() +} + +// Destroy releases resources owned by ep. No other Endpoint methods may be +// called after Destroy. +func (ep *Endpoint) Destroy() { + ep.unmapPacket() +} + +func (ep *Endpoint) unmapPacket() { + syscall.RawSyscall(syscall.SYS_MUNMAP, ep.packet, uintptr(ep.dataCap)+PacketHeaderBytes, 0) + ep.packet = 0 +} + +// Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(), +// ep.RecvFirst(), and ep.SendLast(), as well as the same calls in the peer +// Endpoint, to unblock and return ShutdownErrors. It does not wait for +// concurrent calls to return. Successive calls to Shutdown have no effect. +// +// Shutdown is the only Endpoint method that may be called concurrently with +// other methods on the same Endpoint. +func (ep *Endpoint) Shutdown() { + if atomic.SwapUint32(&ep.shutdown, 1) != 0 { + // ep.Shutdown() has previously been called. + return + } + ep.ctrlShutdown() +} + +// isShutdownLocally returns true if ep.Shutdown() has been called. +func (ep *Endpoint) isShutdownLocally() bool { + return atomic.LoadUint32(&ep.shutdown) != 0 +} + +// ShutdownError is returned by most Endpoint methods after Endpoint.Shutdown() +// has been called. +type ShutdownError struct{} + +// Error implements error.Error. +func (ShutdownError) Error() string { + return "flipcall connection shutdown" +} + +// DataCap returns the maximum datagram size supported by ep. Equivalently, +// DataCap returns len(ep.Data()). +func (ep *Endpoint) DataCap() uint32 { + return ep.dataCap +} + +// Connection state. +const ( + // The client is, by definition, initially active, so this must be 0. + csClientActive = 0 + csServerActive = 1 + csShutdown = 2 +) + +// Connect blocks until the peer Endpoint has called Endpoint.RecvFirst(). +// +// Preconditions: ep is a client Endpoint. ep.Connect(), ep.RecvFirst(), +// ep.SendRecv(), and ep.SendLast() have never been called. +func (ep *Endpoint) Connect() error { + err := ep.ctrlConnect() + if err == nil { + raceBecomeActive() + } + return err +} + +// RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then +// returns the datagram length specified by that call. +// +// Preconditions: ep is a server Endpoint. ep.SendRecv(), ep.RecvFirst(), and +// ep.SendLast() have never been called. +func (ep *Endpoint) RecvFirst() (uint32, error) { + if err := ep.ctrlWaitFirst(); err != nil { + return 0, err + } + raceBecomeActive() + recvDataLen := atomic.LoadUint32(ep.dataLen()) + if recvDataLen > ep.dataCap { + return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap) + } + return recvDataLen, nil +} + +// SendRecv transfers control to the peer Endpoint, causing its call to +// Endpoint.SendRecv() or Endpoint.RecvFirst() to return with the given +// datagram length, then blocks until the peer Endpoint calls +// Endpoint.SendRecv() or Endpoint.SendLast(). +// +// Preconditions: dataLen <= ep.DataCap(). No previous call to ep.SendRecv() or +// ep.RecvFirst() has returned an error. ep.SendLast() has never been called. +// If ep is a client Endpoint, ep.Connect() has previously been called and +// returned nil. +func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) { + if dataLen > ep.dataCap { + panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap)) + } + // This store can safely be non-atomic: Under correct operation we should + // be the only thread writing ep.dataLen(), and ep.ctrlRoundTrip() will + // synchronize with the receiver. We will not read from ep.dataLen() until + // after ep.ctrlRoundTrip(), so if the peer is mutating it concurrently then + // they can only shoot themselves in the foot. + *ep.dataLen() = dataLen + raceBecomeInactive() + if err := ep.ctrlRoundTrip(); err != nil { + return 0, err + } + raceBecomeActive() + recvDataLen := atomic.LoadUint32(ep.dataLen()) + if recvDataLen > ep.dataCap { + return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap) + } + return recvDataLen, nil +} + +// SendLast causes the peer Endpoint's call to Endpoint.SendRecv() or +// Endpoint.RecvFirst() to return with the given datagram length. +// +// Preconditions: dataLen <= ep.DataCap(). No previous call to ep.SendRecv() or +// ep.RecvFirst() has returned an error. ep.SendLast() has never been called. +// If ep is a client Endpoint, ep.Connect() has previously been called and +// returned nil. +func (ep *Endpoint) SendLast(dataLen uint32) error { + if dataLen > ep.dataCap { + panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap)) + } + *ep.dataLen() = dataLen + raceBecomeInactive() + if err := ep.ctrlWakeLast(); err != nil { + return err + } + return nil +} diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go new file mode 100644 index 000000000..2e28a149a --- /dev/null +++ b/pkg/flipcall/flipcall_example_test.go @@ -0,0 +1,113 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package flipcall + +import ( + "bytes" + "fmt" + + "gvisor.dev/gvisor/pkg/sync" +) + +func Example() { + const ( + reqPrefix = "request " + respPrefix = "response " + count = 3 + maxMessageLen = len(respPrefix) + 1 // 1 digit + ) + + pwa, err := NewPacketWindowAllocator() + if err != nil { + panic(err) + } + defer pwa.Destroy() + pwd, err := pwa.Allocate(PacketWindowLengthForDataCap(uint32(maxMessageLen))) + if err != nil { + panic(err) + } + var clientEP Endpoint + if err := clientEP.Init(ClientSide, pwd); err != nil { + panic(err) + } + defer clientEP.Destroy() + var serverEP Endpoint + if err := serverEP.Init(ServerSide, pwd); err != nil { + panic(err) + } + defer serverEP.Destroy() + + var serverRun sync.WaitGroup + serverRun.Add(1) + go func() { + defer serverRun.Done() + i := 0 + var buf bytes.Buffer + // wait for first request + n, err := serverEP.RecvFirst() + if err != nil { + return + } + for { + // read request + buf.Reset() + buf.Write(serverEP.Data()[:n]) + fmt.Println(buf.String()) + // write response + buf.Reset() + fmt.Fprintf(&buf, "%s%d", respPrefix, i) + copy(serverEP.Data(), buf.Bytes()) + // send response and wait for next request + n, err = serverEP.SendRecv(uint32(buf.Len())) + if err != nil { + return + } + i++ + } + }() + defer func() { + serverEP.Shutdown() + serverRun.Wait() + }() + + // establish connection as client + if err := clientEP.Connect(); err != nil { + panic(err) + } + var buf bytes.Buffer + for i := 0; i < count; i++ { + // write request + buf.Reset() + fmt.Fprintf(&buf, "%s%d", reqPrefix, i) + copy(clientEP.Data(), buf.Bytes()) + // send request and wait for response + n, err := clientEP.SendRecv(uint32(buf.Len())) + if err != nil { + panic(err) + } + // read response + buf.Reset() + buf.Write(clientEP.Data()[:n]) + fmt.Println(buf.String()) + } + + // Output: + // request 0 + // response 0 + // request 1 + // response 1 + // request 2 + // response 2 +} diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go new file mode 100644 index 000000000..33fd55a44 --- /dev/null +++ b/pkg/flipcall/flipcall_test.go @@ -0,0 +1,405 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package flipcall + +import ( + "runtime" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/sync" +) + +var testPacketWindowSize = pageSize + +type testConnection struct { + pwa PacketWindowAllocator + clientEP Endpoint + serverEP Endpoint +} + +func newTestConnectionWithOptions(tb testing.TB, clientOpts, serverOpts []EndpointOption) *testConnection { + c := &testConnection{} + if err := c.pwa.Init(); err != nil { + tb.Fatalf("failed to create PacketWindowAllocator: %v", err) + } + pwd, err := c.pwa.Allocate(testPacketWindowSize) + if err != nil { + c.pwa.Destroy() + tb.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err) + } + if err := c.clientEP.Init(ClientSide, pwd, clientOpts...); err != nil { + c.pwa.Destroy() + tb.Fatalf("failed to create client Endpoint: %v", err) + } + if err := c.serverEP.Init(ServerSide, pwd, serverOpts...); err != nil { + c.pwa.Destroy() + c.clientEP.Destroy() + tb.Fatalf("failed to create server Endpoint: %v", err) + } + return c +} + +func newTestConnection(tb testing.TB) *testConnection { + return newTestConnectionWithOptions(tb, nil, nil) +} + +func (c *testConnection) destroy() { + c.pwa.Destroy() + c.clientEP.Destroy() + c.serverEP.Destroy() +} + +func testSendRecv(t *testing.T, c *testConnection) { + // This shared variable is used to confirm that synchronization between + // flipcall endpoints is visible to the Go race detector. + state := 0 + var serverRun sync.WaitGroup + serverRun.Add(1) + go func() { + defer serverRun.Done() + t.Logf("server Endpoint waiting for packet 1") + if _, err := c.serverEP.RecvFirst(); err != nil { + t.Errorf("server Endpoint.RecvFirst() failed: %v", err) + return + } + state++ + if state != 2 { + t.Errorf("shared state counter: got %d, wanted 2", state) + } + t.Logf("server Endpoint got packet 1, sending packet 2 and waiting for packet 3") + if _, err := c.serverEP.SendRecv(0); err != nil { + t.Errorf("server Endpoint.SendRecv() failed: %v", err) + return + } + state++ + if state != 4 { + t.Errorf("shared state counter: got %d, wanted 4", state) + } + t.Logf("server Endpoint got packet 3") + }() + defer func() { + // Ensure that the server goroutine is cleaned up before + // c.serverEP.Destroy(), even if the test fails. + c.serverEP.Shutdown() + serverRun.Wait() + }() + + t.Logf("client Endpoint establishing connection") + if err := c.clientEP.Connect(); err != nil { + t.Fatalf("client Endpoint.Connect() failed: %v", err) + } + state++ + if state != 1 { + t.Errorf("shared state counter: got %d, wanted 1", state) + } + t.Logf("client Endpoint sending packet 1 and waiting for packet 2") + if _, err := c.clientEP.SendRecv(0); err != nil { + t.Fatalf("client Endpoint.SendRecv() failed: %v", err) + } + state++ + if state != 3 { + t.Errorf("shared state counter: got %d, wanted 3", state) + } + t.Logf("client Endpoint got packet 2, sending packet 3") + if err := c.clientEP.SendLast(0); err != nil { + t.Fatalf("client Endpoint.SendLast() failed: %v", err) + } + t.Logf("waiting for server goroutine to complete") + serverRun.Wait() +} + +func TestSendRecv(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testSendRecv(t, c) +} + +func testShutdownBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) { + if remoteShutdown { + c.serverEP.Shutdown() + } else { + c.clientEP.Shutdown() + } + if err := c.clientEP.Connect(); err == nil { + t.Errorf("client Endpoint.Connect() succeeded unexpectedly") + } +} + +func TestShutdownBeforeConnectLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownBeforeConnect(t, c, false) +} + +func TestShutdownBeforeConnectRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownBeforeConnect(t, c, true) +} + +func testShutdownDuringConnect(t *testing.T, c *testConnection, remoteShutdown bool) { + var clientRun sync.WaitGroup + clientRun.Add(1) + go func() { + defer clientRun.Done() + if err := c.clientEP.Connect(); err == nil { + t.Errorf("client Endpoint.Connect() succeeded unexpectedly") + } + }() + time.Sleep(time.Second) // to allow c.clientEP.Connect() to block + if remoteShutdown { + c.serverEP.Shutdown() + } else { + c.clientEP.Shutdown() + } + clientRun.Wait() +} + +func TestShutdownDuringConnectLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringConnect(t, c, false) +} + +func TestShutdownDuringConnectRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringConnect(t, c, true) +} + +func testShutdownBeforeRecvFirst(t *testing.T, c *testConnection, remoteShutdown bool) { + if remoteShutdown { + c.clientEP.Shutdown() + } else { + c.serverEP.Shutdown() + } + if _, err := c.serverEP.RecvFirst(); err == nil { + t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") + } +} + +func TestShutdownBeforeRecvFirstLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownBeforeRecvFirst(t, c, false) +} + +func TestShutdownBeforeRecvFirstRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownBeforeRecvFirst(t, c, true) +} + +func testShutdownDuringRecvFirstBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) { + var serverRun sync.WaitGroup + serverRun.Add(1) + go func() { + defer serverRun.Done() + if _, err := c.serverEP.RecvFirst(); err == nil { + t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") + } + }() + time.Sleep(time.Second) // to allow c.serverEP.RecvFirst() to block + if remoteShutdown { + c.clientEP.Shutdown() + } else { + c.serverEP.Shutdown() + } + serverRun.Wait() +} + +func TestShutdownDuringRecvFirstBeforeConnectLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringRecvFirstBeforeConnect(t, c, false) +} + +func TestShutdownDuringRecvFirstBeforeConnectRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringRecvFirstBeforeConnect(t, c, true) +} + +func testShutdownDuringRecvFirstAfterConnect(t *testing.T, c *testConnection, remoteShutdown bool) { + var serverRun sync.WaitGroup + serverRun.Add(1) + go func() { + defer serverRun.Done() + if _, err := c.serverEP.RecvFirst(); err == nil { + t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") + } + }() + defer func() { + // Ensure that the server goroutine is cleaned up before + // c.serverEP.Destroy(), even if the test fails. + c.serverEP.Shutdown() + serverRun.Wait() + }() + if err := c.clientEP.Connect(); err != nil { + t.Fatalf("client Endpoint.Connect() failed: %v", err) + } + if remoteShutdown { + c.clientEP.Shutdown() + } else { + c.serverEP.Shutdown() + } + serverRun.Wait() +} + +func TestShutdownDuringRecvFirstAfterConnectLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringRecvFirstAfterConnect(t, c, false) +} + +func TestShutdownDuringRecvFirstAfterConnectRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringRecvFirstAfterConnect(t, c, true) +} + +func testShutdownDuringClientSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) { + var serverRun sync.WaitGroup + serverRun.Add(1) + go func() { + defer serverRun.Done() + if _, err := c.serverEP.RecvFirst(); err != nil { + t.Errorf("server Endpoint.RecvFirst() failed: %v", err) + } + // At this point, the client must be blocked in c.clientEP.SendRecv(). + if remoteShutdown { + c.serverEP.Shutdown() + } else { + c.clientEP.Shutdown() + } + }() + defer func() { + // Ensure that the server goroutine is cleaned up before + // c.serverEP.Destroy(), even if the test fails. + c.serverEP.Shutdown() + serverRun.Wait() + }() + if err := c.clientEP.Connect(); err != nil { + t.Fatalf("client Endpoint.Connect() failed: %v", err) + } + if _, err := c.clientEP.SendRecv(0); err == nil { + t.Errorf("client Endpoint.SendRecv() succeeded unexpectedly") + } +} + +func TestShutdownDuringClientSendRecvLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringClientSendRecv(t, c, false) +} + +func TestShutdownDuringClientSendRecvRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringClientSendRecv(t, c, true) +} + +func testShutdownDuringServerSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) { + var serverRun sync.WaitGroup + serverRun.Add(1) + go func() { + defer serverRun.Done() + if _, err := c.serverEP.RecvFirst(); err != nil { + t.Errorf("server Endpoint.RecvFirst() failed: %v", err) + return + } + if _, err := c.serverEP.SendRecv(0); err == nil { + t.Errorf("server Endpoint.SendRecv() succeeded unexpectedly") + } + }() + defer func() { + // Ensure that the server goroutine is cleaned up before + // c.serverEP.Destroy(), even if the test fails. + c.serverEP.Shutdown() + serverRun.Wait() + }() + if err := c.clientEP.Connect(); err != nil { + t.Fatalf("client Endpoint.Connect() failed: %v", err) + } + if _, err := c.clientEP.SendRecv(0); err != nil { + t.Fatalf("client Endpoint.SendRecv() failed: %v", err) + } + time.Sleep(time.Second) // to allow serverEP.SendRecv() to block + if remoteShutdown { + c.clientEP.Shutdown() + } else { + c.serverEP.Shutdown() + } + serverRun.Wait() +} + +func TestShutdownDuringServerSendRecvLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringServerSendRecv(t, c, false) +} + +func TestShutdownDuringServerSendRecvRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringServerSendRecv(t, c, true) +} + +func benchmarkSendRecv(b *testing.B, c *testConnection) { + var serverRun sync.WaitGroup + serverRun.Add(1) + go func() { + defer serverRun.Done() + if b.N == 0 { + return + } + if _, err := c.serverEP.RecvFirst(); err != nil { + b.Errorf("server Endpoint.RecvFirst() failed: %v", err) + return + } + for i := 1; i < b.N; i++ { + if _, err := c.serverEP.SendRecv(0); err != nil { + b.Errorf("server Endpoint.SendRecv() failed: %v", err) + return + } + } + if err := c.serverEP.SendLast(0); err != nil { + b.Errorf("server Endpoint.SendLast() failed: %v", err) + } + }() + defer func() { + c.serverEP.Shutdown() + serverRun.Wait() + }() + + if err := c.clientEP.Connect(); err != nil { + b.Fatalf("client Endpoint.Connect() failed: %v", err) + } + runtime.GC() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := c.clientEP.SendRecv(0); err != nil { + b.Fatalf("client Endpoint.SendRecv() failed: %v", err) + } + } + b.StopTimer() +} + +func BenchmarkSendRecv(b *testing.B) { + c := newTestConnection(b) + defer c.destroy() + benchmarkSendRecv(b, c) +} diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go new file mode 100644 index 000000000..ac974b232 --- /dev/null +++ b/pkg/flipcall/flipcall_unsafe.go @@ -0,0 +1,87 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package flipcall + +import ( + "reflect" + "unsafe" + + "gvisor.dev/gvisor/pkg/sync" +) + +// Packets consist of a 16-byte header followed by an arbitrarily-sized +// datagram. The header consists of: +// +// - A 4-byte native-endian connection state. +// +// - A 4-byte native-endian datagram length in bytes. +// +// - 8 reserved bytes. +const ( + // PacketHeaderBytes is the size of a flipcall packet header in bytes. The + // maximum datagram size supported by a flipcall connection is equal to the + // length of the packet window minus PacketHeaderBytes. + // + // PacketHeaderBytes is exported to support its use in constant + // expressions. Non-constant expressions may prefer to use + // PacketWindowLengthForDataCap(). + PacketHeaderBytes = 16 +) + +func (ep *Endpoint) connState() *uint32 { + return (*uint32)((unsafe.Pointer)(ep.packet)) +} + +func (ep *Endpoint) dataLen() *uint32 { + return (*uint32)((unsafe.Pointer)(ep.packet + 4)) +} + +// Data returns the datagram part of ep's packet window as a byte slice. +// +// Note that the packet window is shared with the potentially-untrusted peer +// Endpoint, which may concurrently mutate the contents of the packet window. +// Thus: +// +// - Readers must not assume that two reads of the same byte in Data() will +// return the same result. In other words, readers should read any given byte +// in Data() at most once. +// +// - Writers must not assume that they will read back the same data that they +// have written. In other words, writers should avoid reading from Data() at +// all. +func (ep *Endpoint) Data() []byte { + var bs []byte + bsReflect := (*reflect.SliceHeader)((unsafe.Pointer)(&bs)) + bsReflect.Data = ep.packet + PacketHeaderBytes + bsReflect.Len = int(ep.dataCap) + bsReflect.Cap = int(ep.dataCap) + return bs +} + +// ioSync is a dummy variable used to indicate synchronization to the Go race +// detector. Compare syscall.ioSync. +var ioSync int64 + +func raceBecomeActive() { + if sync.RaceEnabled { + sync.RaceAcquire((unsafe.Pointer)(&ioSync)) + } +} + +func raceBecomeInactive() { + if sync.RaceEnabled { + sync.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) + } +} diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go new file mode 100644 index 000000000..168c1ccff --- /dev/null +++ b/pkg/flipcall/futex_linux.go @@ -0,0 +1,118 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package flipcall + +import ( + "encoding/json" + "fmt" + "runtime" + "sync/atomic" + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" +) + +func (ep *Endpoint) futexConnect(req *ctrlHandshakeRequest) (ctrlHandshakeResponse, error) { + var resp ctrlHandshakeResponse + + // Write the handshake request. + w := ep.NewWriter() + if err := json.NewEncoder(w).Encode(req); err != nil { + return resp, fmt.Errorf("error writing handshake request: %v", err) + } + *ep.dataLen() = w.Len() + + // Exchange control with the server. + if err := ep.futexSwitchToPeer(); err != nil { + return resp, err + } + if err := ep.futexSwitchFromPeer(); err != nil { + return resp, err + } + + // Read the handshake response. + respLen := atomic.LoadUint32(ep.dataLen()) + if respLen > ep.dataCap { + return resp, fmt.Errorf("invalid handshake response length %d (maximum %d)", respLen, ep.dataCap) + } + if err := json.NewDecoder(ep.NewReader(respLen)).Decode(&resp); err != nil { + return resp, fmt.Errorf("error reading handshake response: %v", err) + } + + return resp, nil +} + +func (ep *Endpoint) futexSwitchToPeer() error { + // Update connection state to indicate that the peer should be active. + if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) { + switch cs := atomic.LoadUint32(ep.connState()); cs { + case csShutdown: + return ShutdownError{} + default: + return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs) + } + } + + // Wake the peer's Endpoint.futexSwitchFromPeer(). + if err := ep.futexWakeConnState(1); err != nil { + return fmt.Errorf("failed to FUTEX_WAKE peer Endpoint: %v", err) + } + return nil +} + +func (ep *Endpoint) futexSwitchFromPeer() error { + for { + switch cs := atomic.LoadUint32(ep.connState()); cs { + case ep.activeState: + return nil + case ep.inactiveState: + if ep.isShutdownLocally() { + return ShutdownError{} + } + if err := ep.futexWaitConnState(ep.inactiveState); err != nil { + return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err) + } + continue + case csShutdown: + return ShutdownError{} + default: + return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs) + } + } +} + +func (ep *Endpoint) futexWakeConnState(numThreads int32) error { + if _, _, e := syscall.RawSyscall(syscall.SYS_FUTEX, ep.packet, linux.FUTEX_WAKE, uintptr(numThreads)); e != 0 { + return e + } + return nil +} + +func (ep *Endpoint) futexWaitConnState(curState uint32) error { + _, _, e := syscall.Syscall6(syscall.SYS_FUTEX, ep.packet, linux.FUTEX_WAIT, uintptr(curState), 0, 0, 0) + if e != 0 && e != syscall.EAGAIN && e != syscall.EINTR { + return e + } + return nil +} + +func yieldThread() { + syscall.Syscall(syscall.SYS_SCHED_YIELD, 0, 0, 0) + // The thread we're trying to yield to may be waiting for a Go runtime P. + // runtime.Gosched() will hand off ours if necessary. + runtime.Gosched() +} diff --git a/pkg/flipcall/io.go b/pkg/flipcall/io.go new file mode 100644 index 000000000..85e40b932 --- /dev/null +++ b/pkg/flipcall/io.go @@ -0,0 +1,113 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package flipcall + +import ( + "fmt" + "io" +) + +// DatagramReader implements io.Reader by reading a datagram from an Endpoint's +// packet window. Its use is optional; users that can use Endpoint.Data() more +// efficiently are advised to do so. +type DatagramReader struct { + ep *Endpoint + off uint32 + end uint32 +} + +// Init must be called on zero-value DatagramReaders before first use. +// +// Preconditions: dataLen is 0, or was returned by a previous call to +// ep.RecvFirst() or ep.SendRecv(). +func (r *DatagramReader) Init(ep *Endpoint, dataLen uint32) { + r.ep = ep + r.Reset(dataLen) +} + +// Reset causes r to begin reading a new datagram of the given length from the +// associated Endpoint. +// +// Preconditions: dataLen is 0, or was returned by a previous call to the +// associated Endpoint's RecvFirst() or SendRecv() methods. +func (r *DatagramReader) Reset(dataLen uint32) { + if dataLen > r.ep.dataCap { + panic(fmt.Sprintf("invalid dataLen (%d) > ep.dataCap (%d)", dataLen, r.ep.dataCap)) + } + r.off = 0 + r.end = dataLen +} + +// NewReader is a convenience function that returns an initialized +// DatagramReader allocated on the heap. +// +// Preconditions: dataLen was returned by a previous call to ep.RecvFirst() or +// ep.SendRecv(). +func (ep *Endpoint) NewReader(dataLen uint32) *DatagramReader { + r := &DatagramReader{} + r.Init(ep, dataLen) + return r +} + +// Read implements io.Reader.Read. +func (r *DatagramReader) Read(dst []byte) (int, error) { + n := copy(dst, r.ep.Data()[r.off:r.end]) + r.off += uint32(n) + if r.off == r.end { + return n, io.EOF + } + return n, nil +} + +// DatagramWriter implements io.Writer by writing a datagram to an Endpoint's +// packet window. Its use is optional; users that can use Endpoint.Data() more +// efficiently are advised to do so. +type DatagramWriter struct { + ep *Endpoint + off uint32 +} + +// Init must be called on zero-value DatagramWriters before first use. +func (w *DatagramWriter) Init(ep *Endpoint) { + w.ep = ep +} + +// Reset causes w to begin writing a new datagram to the associated Endpoint. +func (w *DatagramWriter) Reset() { + w.off = 0 +} + +// NewWriter is a convenience function that returns an initialized +// DatagramWriter allocated on the heap. +func (ep *Endpoint) NewWriter() *DatagramWriter { + w := &DatagramWriter{} + w.Init(ep) + return w +} + +// Write implements io.Writer.Write. +func (w *DatagramWriter) Write(src []byte) (int, error) { + n := copy(w.ep.Data()[w.off:w.ep.dataCap], src) + w.off += uint32(n) + if n != len(src) { + return n, fmt.Errorf("datagram would exceed maximum size of %d bytes", w.ep.dataCap) + } + return n, nil +} + +// Len returns the length of the written datagram. +func (w *DatagramWriter) Len() uint32 { + return w.off +} diff --git a/pkg/flipcall/packet_window_allocator.go b/pkg/flipcall/packet_window_allocator.go new file mode 100644 index 000000000..af9cc3d21 --- /dev/null +++ b/pkg/flipcall/packet_window_allocator.go @@ -0,0 +1,166 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package flipcall + +import ( + "fmt" + "math/bits" + "os" + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/memutil" +) + +var ( + pageSize = os.Getpagesize() + pageMask = pageSize - 1 +) + +func init() { + if bits.OnesCount(uint(pageSize)) != 1 { + // This is depended on by roundUpToPage(). + panic(fmt.Sprintf("system page size (%d) is not a power of 2", pageSize)) + } + if uintptr(pageSize) < PacketHeaderBytes { + // This is required since Endpoint.Init() imposes a minimum packet + // window size of 1 page. + panic(fmt.Sprintf("system page size (%d) is less than packet header size (%d)", pageSize, PacketHeaderBytes)) + } +} + +// PacketWindowDescriptor represents a packet window, a range of pages in a +// shared memory file that is used to exchange packets between partner +// Endpoints. +type PacketWindowDescriptor struct { + // FD is the file descriptor representing the shared memory file. + FD int + + // Offset is the offset into the shared memory file at which the packet + // window begins. + Offset int64 + + // Length is the size of the packet window in bytes. + Length int +} + +// PacketWindowLengthForDataCap returns the minimum packet window size required +// to accommodate datagrams of the given size in bytes. +func PacketWindowLengthForDataCap(dataCap uint32) int { + return roundUpToPage(int(dataCap) + int(PacketHeaderBytes)) +} + +func roundUpToPage(x int) int { + return (x + pageMask) &^ pageMask +} + +// A PacketWindowAllocator owns a shared memory file, and allocates packet +// windows from it. +type PacketWindowAllocator struct { + fd int + nextAlloc int64 + fileSize int64 +} + +// Init must be called on zero-value PacketWindowAllocators before first use. +// If it succeeds, Destroy() must be called once the PacketWindowAllocator is +// no longer in use. +func (pwa *PacketWindowAllocator) Init() error { + fd, err := memutil.CreateMemFD("flipcall_packet_windows", linux.MFD_CLOEXEC|linux.MFD_ALLOW_SEALING) + if err != nil { + return fmt.Errorf("failed to create memfd: %v", err) + } + // Apply F_SEAL_SHRINK to prevent either party from causing SIGBUS in the + // other by truncating the file, and F_SEAL_SEAL to prevent either party + // from applying F_SEAL_GROW or F_SEAL_WRITE. + if _, _, e := syscall.RawSyscall(syscall.SYS_FCNTL, uintptr(fd), linux.F_ADD_SEALS, linux.F_SEAL_SHRINK|linux.F_SEAL_SEAL); e != 0 { + syscall.Close(fd) + return fmt.Errorf("failed to apply memfd seals: %v", e) + } + pwa.fd = fd + return nil +} + +// NewPacketWindowAllocator is a convenience function that returns an +// initialized PacketWindowAllocator allocated on the heap. +func NewPacketWindowAllocator() (*PacketWindowAllocator, error) { + var pwa PacketWindowAllocator + if err := pwa.Init(); err != nil { + return nil, err + } + return &pwa, nil +} + +// Destroy releases resources owned by pwa. This invalidates file descriptors +// previously returned by pwa.FD() and pwd.Allocate(). +func (pwa *PacketWindowAllocator) Destroy() { + syscall.Close(pwa.fd) +} + +// FD represents the file descriptor of the shared memory file backing pwa. +func (pwa *PacketWindowAllocator) FD() int { + return pwa.fd +} + +// Allocate allocates a new packet window of at least the given size and +// returns a PacketWindowDescriptor representing it. +// +// Preconditions: size > 0. +func (pwa *PacketWindowAllocator) Allocate(size int) (PacketWindowDescriptor, error) { + if size <= 0 { + return PacketWindowDescriptor{}, fmt.Errorf("invalid size: %d", size) + } + // Page-align size to ensure that pwa.nextAlloc remains page-aligned. + size = roundUpToPage(size) + if size <= 0 { + return PacketWindowDescriptor{}, fmt.Errorf("size %d overflows after rounding up to page size", size) + } + end := pwa.nextAlloc + int64(size) // overflow checked by ensureFileSize + if err := pwa.ensureFileSize(end); err != nil { + return PacketWindowDescriptor{}, err + } + start := pwa.nextAlloc + pwa.nextAlloc = end + return PacketWindowDescriptor{ + FD: pwa.FD(), + Offset: start, + Length: size, + }, nil +} + +func (pwa *PacketWindowAllocator) ensureFileSize(min int64) error { + if min <= 0 { + return fmt.Errorf("file size would overflow") + } + if pwa.fileSize >= min { + return nil + } + newSize := 2 * pwa.fileSize + if newSize == 0 { + newSize = int64(pageSize) + } + for newSize < min { + newNewSize := newSize * 2 + if newNewSize <= 0 { + return fmt.Errorf("file size would overflow") + } + newSize = newNewSize + } + if err := syscall.Ftruncate(pwa.FD(), newSize); err != nil { + return fmt.Errorf("ftruncate failed: %v", err) + } + pwa.fileSize = newSize + return nil +} |