summaryrefslogtreecommitdiffhomepage
path: root/pkg/merkletree/merkletree_test.go
diff options
context:
space:
mode:
authorChong Cai <chongc@google.com>2020-10-12 17:28:20 -0700
committergVisor bot <gvisor-bot@google.com>2020-10-12 17:30:14 -0700
commitef90fe173380a8d769c699aec08737ef56f43c3e (patch)
tree5ef76cba87b92bb950427976b5b7cea158ef2418 /pkg/merkletree/merkletree_test.go
parente7bbe70f79aa9308c2eb54b057ee5779b22f478e (diff)
Change Merkle tree library to use ReaderAt
Merkle tree library was originally using Read/Seek to access data and tree, since the parameters are io.ReadSeeker. This could cause race conditions if multiple threads accesses the same fd to read. Here we change to use ReaderAt, and implement it with PRead to make it thread safe. PiperOrigin-RevId: 336779260
Diffstat (limited to 'pkg/merkletree/merkletree_test.go')
-rw-r--r--pkg/merkletree/merkletree_test.go52
1 files changed, 15 insertions, 37 deletions
diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go
index bb11ec844..e1350ebda 100644
--- a/pkg/merkletree/merkletree_test.go
+++ b/pkg/merkletree/merkletree_test.go
@@ -106,58 +106,36 @@ func (brw *bytesReadWriter) Write(p []byte) (int, error) {
return len(p), nil
}
-func (brw *bytesReadWriter) Read(p []byte) (int, error) {
- if brw.readPos >= len(brw.bytes) {
- return 0, io.EOF
- }
- bytesRead := copy(p, brw.bytes[brw.readPos:])
- brw.readPos += bytesRead
- if bytesRead < len(p) {
+func (brw *bytesReadWriter) ReadAt(p []byte, off int64) (int, error) {
+ bytesRead := copy(p, brw.bytes[off:])
+ if bytesRead == 0 {
return bytesRead, io.EOF
}
return bytesRead, nil
}
-func (brw *bytesReadWriter) Seek(offset int64, whence int) (int64, error) {
- off := offset
- if whence == io.SeekCurrent {
- off += int64(brw.readPos)
- }
- if whence == io.SeekEnd {
- off += int64(len(brw.bytes))
- }
- if off < 0 {
- panic("seek with negative offset")
- }
- if off >= int64(len(brw.bytes)) {
- return 0, io.EOF
- }
- brw.readPos = int(off)
- return off, nil
-}
-
func TestGenerate(t *testing.T) {
// The input data has size dataSize. It starts with the data in startWith,
// and all other bytes are zeroes.
testCases := []struct {
data []byte
- expectedRoot []byte
+ expectedHash []byte
}{
{
data: bytes.Repeat([]byte{0}, usermem.PageSize),
- expectedRoot: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252},
+ expectedHash: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252},
},
{
data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
- expectedRoot: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26},
+ expectedHash: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26},
},
{
data: []byte{'a'},
- expectedRoot: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11},
+ expectedHash: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11},
},
{
data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
- expectedRoot: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79},
+ expectedHash: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79},
},
}
@@ -183,13 +161,13 @@ func TestGenerate(t *testing.T) {
bytes: tc.data,
}
}
- root, err := Generate(&params)
+ hash, err := Generate(&params)
if err != nil {
t.Fatalf("Got err: %v, want nil", err)
}
- if !bytes.Equal(root, tc.expectedRoot) {
- t.Errorf("Got root: %v, want %v", root, tc.expectedRoot)
+ if !bytes.Equal(hash, tc.expectedHash) {
+ t.Errorf("Got hash: %v, want %v", hash, tc.expectedHash)
}
}
})
@@ -390,7 +368,7 @@ func TestVerify(t *testing.T) {
bytes: data,
}
}
- root, err := Generate(&genParams)
+ hash, err := Generate(&genParams)
if err != nil {
t.Fatalf("Generate failed: %v", err)
}
@@ -409,7 +387,7 @@ func TestVerify(t *testing.T) {
GID: defaultGID,
ReadOffset: tc.verifyStart,
ReadSize: tc.verifySize,
- ExpectedRoot: root,
+ Expected: hash,
DataAndTreeInSameFile: dataAndTreeInSameFile,
}
if tc.modifyName {
@@ -478,7 +456,7 @@ func TestVerifyRandom(t *testing.T) {
bytes: data,
}
}
- root, err := Generate(&genParams)
+ hash, err := Generate(&genParams)
if err != nil {
t.Fatalf("Generate failed: %v", err)
}
@@ -499,7 +477,7 @@ func TestVerifyRandom(t *testing.T) {
GID: defaultGID,
ReadOffset: start,
ReadSize: size,
- ExpectedRoot: root,
+ Expected: hash,
DataAndTreeInSameFile: dataAndTreeInSameFile,
}