// Copyright 2018 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package safemem import ( "bytes" "fmt" "unsafe" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/gohacks" ) // A BlockSeq represents a sequence of Blocks, each of which has non-zero // length. // // BlockSeqs are immutable and may be copied by value. The zero value of // BlockSeq represents an empty sequence. type BlockSeq struct { // If length is 0, then the BlockSeq is empty. Invariants: data == 0; // offset == 0; limit == 0. // // If length is -1, then the BlockSeq represents the single Block{data, // limit, false}. Invariants: offset == 0; limit > 0; limit does not // overflow the range of an int. // // If length is -2, then the BlockSeq represents the single Block{data, // limit, true}. Invariants: offset == 0; limit > 0; limit does not // overflow the range of an int. // // Otherwise, length >= 2, and the BlockSeq represents the `length` Blocks // in the array of Blocks starting at address `data`, starting at `offset` // bytes into the first Block and limited to the following `limit` bytes. // Invariants: data != 0; offset < len(data[0]); limit > 0; offset+limit <= // the combined length of all Blocks in the array; the first Block in the // array has non-zero length. // // length is never 1; sequences consisting of a single Block are always // stored inline (with length < 0). data unsafe.Pointer length int offset int limit uint64 } // BlockSeqOf returns a BlockSeq representing the single Block b. func BlockSeqOf(b Block) BlockSeq { if b.length == 0 { return BlockSeq{} } bs := BlockSeq{ data: b.start, length: -1, limit: uint64(b.length), } if b.needSafecopy { bs.length = -2 } return bs } // BlockSeqFromSlice returns a BlockSeq representing all Blocks in slice. // If slice contains Blocks with zero length, BlockSeq will skip them during // iteration. // // Whether the returned BlockSeq shares memory with slice is unspecified; // clients should avoid mutating slices passed to BlockSeqFromSlice. // // Preconditions: The combined length of all Blocks in slice <= math.MaxUint64. func BlockSeqFromSlice(slice []Block) BlockSeq { slice = skipEmpty(slice) var limit uint64 for _, b := range slice { sum := limit + uint64(b.Len()) if sum < limit { panic("BlockSeq length overflows uint64") } limit = sum } return blockSeqFromSliceLimited(slice, limit) } // Preconditions: // * The combined length of all Blocks in slice <= limit. // * If len(slice) != 0, the first Block in slice has non-zero length and // limit > 0. func blockSeqFromSliceLimited(slice []Block, limit uint64) BlockSeq { switch len(slice) { case 0: return BlockSeq{} case 1: return BlockSeqOf(slice[0].TakeFirst64(limit)) default: return BlockSeq{ data: unsafe.Pointer(&slice[0]), length: len(slice), limit: limit, } } } func skipEmpty(slice []Block) []Block { for i, b := range slice { if b.Len() != 0 { return slice[i:] } } return nil } // IsEmpty returns true if bs contains no Blocks. // // Invariants: bs.IsEmpty() == (bs.NumBlocks() == 0) == (bs.NumBytes() == 0). // (Of these, prefer to use bs.IsEmpty().) func (bs BlockSeq) IsEmpty() bool { return bs.length == 0 } // NumBlocks returns the number of Blocks in bs. func (bs BlockSeq) NumBlocks() int { // In general, we have to count: if bs represents a windowed slice then the // slice may contain Blocks with zero length, and bs.length may be larger // than the actual number of Blocks due to bs.limit. var n int for !bs.IsEmpty() { n++ bs = bs.Tail() } return n } // NumBytes returns the sum of Block.Len() for all Blocks in bs. func (bs BlockSeq) NumBytes() uint64 { return bs.limit } // Head returns the first Block in bs. // // Preconditions: !bs.IsEmpty(). func (bs BlockSeq) Head() Block { if bs.length == 0 { panic("empty BlockSeq") } if bs.length < 0 { return bs.internalBlock() } return (*Block)(bs.data).DropFirst(bs.offset).TakeFirst64(bs.limit) } // Preconditions: bs.length < 0. func (bs BlockSeq) internalBlock() Block { return Block{ start: bs.data, length: int(bs.limit), needSafecopy: bs.length == -2, } } // Tail returns a BlockSeq consisting of all Blocks in bs after the first. // // Preconditions: !bs.IsEmpty(). func (bs BlockSeq) Tail() BlockSeq { if bs.length == 0 { panic("empty BlockSeq") } if bs.length < 0 { return BlockSeq{} } head := (*Block)(bs.data).DropFirst(bs.offset) headLen := uint64(head.Len()) if headLen >= bs.limit { // The head Block exhausts the limit, so the tail is empty. return BlockSeq{} } var extSlice []Block extSliceHdr := (*gohacks.SliceHeader)(unsafe.Pointer(&extSlice)) extSliceHdr.Data = bs.data extSliceHdr.Len = bs.length extSliceHdr.Cap = bs.length tailSlice := skipEmpty(extSlice[1:]) tailLimit := bs.limit - headLen return blockSeqFromSliceLimited(tailSlice, tailLimit) } // DropFirst returns a BlockSeq equivalent to bs, but with the first n bytes // omitted. If n > bs.NumBytes(), DropFirst returns an empty BlockSeq. // // Preconditions: n >= 0. func (bs BlockSeq) DropFirst(n int) BlockSeq { if n < 0 { panic(fmt.Sprintf("invalid n: %d", n)) } return bs.DropFirst64(uint64(n)) } // DropFirst64 is equivalent to DropFirst but takes an uint64. func (bs BlockSeq) DropFirst64(n uint64) BlockSeq { if n >= bs.limit { return BlockSeq{} } for { // Calling bs.Head() here is surprisingly expensive, so inline getting // the head's length. var headLen uint64 if bs.length < 0 { headLen = bs.limit } else { headLen = uint64((*Block)(bs.data).Len() - bs.offset) } if n < headLen { // Dropping ends partway through the head Block. if bs.length < 0 { return BlockSeqOf(bs.internalBlock().DropFirst64(n)) } bs.offset += int(n) bs.limit -= n return bs } n -= headLen bs = bs.Tail() } } // TakeFirst returns a BlockSeq equivalent to the first n bytes of bs. If n > // bs.NumBytes(), TakeFirst returns a BlockSeq equivalent to bs. // // Preconditions: n >= 0. func (bs BlockSeq) TakeFirst(n int) BlockSeq { if n < 0 { panic(fmt.Sprintf("invalid n: %d", n)) } return bs.TakeFirst64(uint64(n)) } // TakeFirst64 is equivalent to TakeFirst but takes a uint64. func (bs BlockSeq) TakeFirst64(n uint64) BlockSeq { if n == 0 { return BlockSeq{} } if bs.limit > n { bs.limit = n } return bs } // String implements fmt.Stringer.String. func (bs BlockSeq) String() string { var buf bytes.Buffer buf.WriteByte('[') var sep string for !bs.IsEmpty() { buf.WriteString(sep) sep = " " buf.WriteString(bs.Head().String()) bs = bs.Tail() } buf.WriteByte(']') return buf.String() } // CopySeq copies srcs.NumBytes() or dsts.NumBytes() bytes, whichever is less, // from srcs to dsts and returns the number of bytes copied. // // If srcs and dsts overlap, the data stored in dsts is unspecified. func CopySeq(dsts, srcs BlockSeq) (uint64, error) { var done uint64 for !dsts.IsEmpty() && !srcs.IsEmpty() { dst := dsts.Head() src := srcs.Head() n, err := Copy(dst, src) done += uint64(n) if err != nil { return done, err } dsts = dsts.DropFirst(n) srcs = srcs.DropFirst(n) } return done, nil } // ZeroSeq sets all bytes in dsts to 0 and returns the number of bytes zeroed. func ZeroSeq(dsts BlockSeq) (uint64, error) { var done uint64 for !dsts.IsEmpty() { n, err := Zero(dsts.Head()) done += uint64(n) if err != nil { return done, err } dsts = dsts.DropFirst(n) } return done, nil } // IovecsFromBlockSeq returns a []unix.Iovec representing seq. func IovecsFromBlockSeq(bs BlockSeq) []unix.Iovec { iovs := make([]unix.Iovec, 0, bs.NumBlocks()) for ; !bs.IsEmpty(); bs = bs.Tail() { b := bs.Head() iovs = append(iovs, unix.Iovec{ Base: &b.ToSlice()[0], Len: uint64(b.Len()), }) // We don't need to care about b.NeedSafecopy(), because the host // kernel will handle such address ranges just fine (by returning // EFAULT). } return iovs }