diff options
-rw-r--r-- | pkg/lisafs/client_file.go | 101 |
1 files changed, 77 insertions, 24 deletions
diff --git a/pkg/lisafs/client_file.go b/pkg/lisafs/client_file.go index 0f8788f3b..170c15705 100644 --- a/pkg/lisafs/client_file.go +++ b/pkg/lisafs/client_file.go @@ -15,6 +15,8 @@ package lisafs import ( + "fmt" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -121,41 +123,92 @@ func (f *ClientFD) Sync(ctx context.Context) error { return err } +// chunkify applies fn to buf in chunks based on chunkSize. +func chunkify(chunkSize uint64, buf []byte, fn func([]byte, uint64) (uint64, error)) (uint64, error) { + toProcess := uint64(len(buf)) + var ( + totalProcessed uint64 + curProcessed uint64 + off uint64 + err error + ) + for { + if totalProcessed == toProcess { + return totalProcessed, nil + } + + if totalProcessed+chunkSize > toProcess { + curProcessed, err = fn(buf[totalProcessed:], off) + } else { + curProcessed, err = fn(buf[totalProcessed:totalProcessed+chunkSize], off) + } + totalProcessed += curProcessed + off += curProcessed + + if err != nil { + return totalProcessed, err + } + + // Return partial result immediately. + if curProcessed < chunkSize { + return totalProcessed, nil + } + + // If we received more bytes than we ever requested, this is a problem. + if totalProcessed > toProcess { + panic(fmt.Sprintf("bytes completed (%d)) > requested (%d)", totalProcessed, toProcess)) + } + } +} + // Read makes the PRead RPC. func (f *ClientFD) Read(ctx context.Context, dst []byte, offset uint64) (uint64, error) { - req := PReadReq{ - Offset: offset, - FD: f.fd, - Count: uint32(len(dst)), - } + var resp PReadResp + // maxDataReadSize represents the maximum amount of data we can read at once + // (maximum message size - metadata size present in resp). Uninitialized + // resp.SizeBytes() correctly returns the metadata size only (since the read + // buffer is empty). + maxDataReadSize := uint64(f.client.maxMessageSize) - uint64(resp.SizeBytes()) + return chunkify(maxDataReadSize, dst, func(buf []byte, curOff uint64) (uint64, error) { + req := PReadReq{ + Offset: offset + curOff, + FD: f.fd, + Count: uint32(len(buf)), + } - resp := PReadResp{ // This will be unmarshalled into. Already set Buf so that we don't need to // allocate a temporary buffer during unmarshalling. // PReadResp.UnmarshalBytes expects this to be set. - Buf: dst, - } - - ctx.UninterruptibleSleepStart(false) - err := f.client.SndRcvMessage(PRead, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil) - ctx.UninterruptibleSleepFinish(false) - return uint64(resp.NumBytes), err + resp.Buf = buf + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(PRead, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return uint64(resp.NumBytes), err + }) } // Write makes the PWrite RPC. func (f *ClientFD) Write(ctx context.Context, src []byte, offset uint64) (uint64, error) { - req := PWriteReq{ - Offset: primitive.Uint64(offset), - FD: f.fd, - NumBytes: primitive.Uint32(len(src)), - Buf: src, - } + var req PWriteReq + // maxDataWriteSize represents the maximum amount of data we can write at + // once (maximum message size - metadata size present in req). Uninitialized + // req.SizeBytes() correctly returns the metadata size only (since the write + // buffer is empty). + maxDataWriteSize := uint64(f.client.maxMessageSize) - uint64(req.SizeBytes()) + return chunkify(maxDataWriteSize, src, func(buf []byte, curOff uint64) (uint64, error) { + req = PWriteReq{ + Offset: primitive.Uint64(offset + curOff), + FD: f.fd, + NumBytes: primitive.Uint32(len(buf)), + Buf: buf, + } - var resp PWriteResp - ctx.UninterruptibleSleepStart(false) - err := f.client.SndRcvMessage(PWrite, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) - ctx.UninterruptibleSleepFinish(false) - return resp.Count, err + var resp PWriteResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(PWrite, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.Count, err + }) } // MkdirAt makes the MkdirAt RPC. |