diff options
Diffstat (limited to 'pkg/p9/buffer.go')
-rw-r--r-- | pkg/p9/buffer.go | 262 |
1 files changed, 262 insertions, 0 deletions
diff --git a/pkg/p9/buffer.go b/pkg/p9/buffer.go new file mode 100644 index 000000000..fc65d2c5f --- /dev/null +++ b/pkg/p9/buffer.go @@ -0,0 +1,262 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "encoding/binary" +) + +// encoder is used for messages and 9P primitives. +type encoder interface { + // Decode decodes from the given buffer. + // + // This may not fail, exhaustion will be recorded in the buffer. + Decode(b *buffer) + + // Encode encodes to the given buffer. + // + // This may not fail. + Encode(b *buffer) +} + +// order is the byte order used for encoding. +var order = binary.LittleEndian + +// buffer is a slice that is consumed. +// +// This is passed to the encoder methods. +type buffer struct { + // data is the underlying data. This may grow during Encode. + data []byte + + // overflow indicates whether an overflow has occurred. + overflow bool +} + +// append appends n bytes to the buffer and returns a slice pointing to the +// newly appended bytes. +func (b *buffer) append(n int) []byte { + b.data = append(b.data, make([]byte, n)...) + return b.data[len(b.data)-n:] +} + +// consume consumes n bytes from the buffer. +func (b *buffer) consume(n int) ([]byte, bool) { + if !b.has(n) { + b.markOverrun() + return nil, false + } + rval := b.data[:n] + b.data = b.data[n:] + return rval, true +} + +// has returns true if n bytes are available. +func (b *buffer) has(n int) bool { + return len(b.data) >= n +} + +// markOverrun immediately marks this buffer as overrun. +// +// This is used by ReadString, since some invalid data implies the rest of the +// buffer is no longer valid either. +func (b *buffer) markOverrun() { + b.overflow = true +} + +// isOverrun returns true if this buffer has run past the end. +func (b *buffer) isOverrun() bool { + return b.overflow +} + +// Read8 reads a byte from the buffer. +func (b *buffer) Read8() uint8 { + v, ok := b.consume(1) + if !ok { + return 0 + } + return uint8(v[0]) +} + +// Read16 reads a 16-bit value from the buffer. +func (b *buffer) Read16() uint16 { + v, ok := b.consume(2) + if !ok { + return 0 + } + return order.Uint16(v) +} + +// Read32 reads a 32-bit value from the buffer. +func (b *buffer) Read32() uint32 { + v, ok := b.consume(4) + if !ok { + return 0 + } + return order.Uint32(v) +} + +// Read64 reads a 64-bit value from the buffer. +func (b *buffer) Read64() uint64 { + v, ok := b.consume(8) + if !ok { + return 0 + } + return order.Uint64(v) +} + +// ReadQIDType reads a QIDType value. +func (b *buffer) ReadQIDType() QIDType { + return QIDType(b.Read8()) +} + +// ReadTag reads a Tag value. +func (b *buffer) ReadTag() Tag { + return Tag(b.Read16()) +} + +// ReadFID reads a FID value. +func (b *buffer) ReadFID() FID { + return FID(b.Read32()) +} + +// ReadUID reads a UID value. +func (b *buffer) ReadUID() UID { + return UID(b.Read32()) +} + +// ReadGID reads a GID value. +func (b *buffer) ReadGID() GID { + return GID(b.Read32()) +} + +// ReadPermissions reads a file mode value and applies the mask for permissions. +func (b *buffer) ReadPermissions() FileMode { + return b.ReadFileMode() & PermissionsMask +} + +// ReadFileMode reads a file mode value. +func (b *buffer) ReadFileMode() FileMode { + return FileMode(b.Read32()) +} + +// ReadOpenFlags reads an OpenFlags. +func (b *buffer) ReadOpenFlags() OpenFlags { + return OpenFlags(b.Read32()) +} + +// ReadConnectFlags reads a ConnectFlags. +func (b *buffer) ReadConnectFlags() ConnectFlags { + return ConnectFlags(b.Read32()) +} + +// ReadMsgType writes a MsgType. +func (b *buffer) ReadMsgType() MsgType { + return MsgType(b.Read8()) +} + +// ReadString deserializes a string. +func (b *buffer) ReadString() string { + l := b.Read16() + if !b.has(int(l)) { + // Mark the buffer as corrupted. + b.markOverrun() + return "" + } + + bs := make([]byte, l) + for i := 0; i < int(l); i++ { + bs[i] = byte(b.Read8()) + } + return string(bs) +} + +// Write8 writes a byte to the buffer. +func (b *buffer) Write8(v uint8) { + b.append(1)[0] = byte(v) +} + +// Write16 writes a 16-bit value to the buffer. +func (b *buffer) Write16(v uint16) { + order.PutUint16(b.append(2), v) +} + +// Write32 writes a 32-bit value to the buffer. +func (b *buffer) Write32(v uint32) { + order.PutUint32(b.append(4), v) +} + +// Write64 writes a 64-bit value to the buffer. +func (b *buffer) Write64(v uint64) { + order.PutUint64(b.append(8), v) +} + +// WriteQIDType writes a QIDType value. +func (b *buffer) WriteQIDType(qidType QIDType) { + b.Write8(uint8(qidType)) +} + +// WriteTag writes a Tag value. +func (b *buffer) WriteTag(tag Tag) { + b.Write16(uint16(tag)) +} + +// WriteFID writes a FID value. +func (b *buffer) WriteFID(fid FID) { + b.Write32(uint32(fid)) +} + +// WriteUID writes a UID value. +func (b *buffer) WriteUID(uid UID) { + b.Write32(uint32(uid)) +} + +// WriteGID writes a GID value. +func (b *buffer) WriteGID(gid GID) { + b.Write32(uint32(gid)) +} + +// WritePermissions applies a permissions mask and writes the FileMode. +func (b *buffer) WritePermissions(perm FileMode) { + b.WriteFileMode(perm & PermissionsMask) +} + +// WriteFileMode writes a FileMode. +func (b *buffer) WriteFileMode(mode FileMode) { + b.Write32(uint32(mode)) +} + +// WriteOpenFlags writes an OpenFlags. +func (b *buffer) WriteOpenFlags(flags OpenFlags) { + b.Write32(uint32(flags)) +} + +// WriteConnectFlags writes a ConnectFlags. +func (b *buffer) WriteConnectFlags(flags ConnectFlags) { + b.Write32(uint32(flags)) +} + +// WriteMsgType writes a MsgType. +func (b *buffer) WriteMsgType(t MsgType) { + b.Write8(uint8(t)) +} + +// WriteString serializes the given string. +func (b *buffer) WriteString(s string) { + b.Write16(uint16(len(s))) + for i := 0; i < len(s); i++ { + b.Write8(byte(s[i])) + } +} |