summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/fs
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/fs')
-rw-r--r--pkg/sentry/fs/file.go2
-rw-r--r--pkg/sentry/fs/lock/lock.go41
-rw-r--r--pkg/sentry/fs/lock/lock_set_functions.go8
-rw-r--r--pkg/sentry/fs/lock/lock_test.go111
4 files changed, 73 insertions, 89 deletions
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
index 2a278fbe3..ca41520b4 100644
--- a/pkg/sentry/fs/file.go
+++ b/pkg/sentry/fs/file.go
@@ -146,7 +146,7 @@ func (f *File) DecRef() {
f.DecRefWithDestructor(func() {
// Drop BSD style locks.
lockRng := lock.LockRange{Start: 0, End: lock.LockEOF}
- f.Dirent.Inode.LockCtx.BSD.UnlockRegion(lock.UniqueID(f.UniqueID), lockRng)
+ f.Dirent.Inode.LockCtx.BSD.UnlockRegion(f, lockRng)
// Release resources held by the FileOperations.
f.FileOperations.Release()
diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go
index 926538d90..8a5d9c7eb 100644
--- a/pkg/sentry/fs/lock/lock.go
+++ b/pkg/sentry/fs/lock/lock.go
@@ -62,7 +62,7 @@ import (
type LockType int
// UniqueID is a unique identifier of the holder of a regional file lock.
-type UniqueID uint64
+type UniqueID interface{}
const (
// ReadLock describes a POSIX regional file lock to be taken
@@ -98,12 +98,7 @@ type Lock struct {
// If len(Readers) > 0 then HasWriter must be false.
Readers map[UniqueID]bool
- // HasWriter indicates that this is a write lock held by a single
- // UniqueID.
- HasWriter bool
-
- // Writer is only valid if HasWriter is true. It identifies a
- // single write lock holder.
+ // Writer holds the writer unique ID. It's nil if there are no writers.
Writer UniqueID
}
@@ -186,7 +181,6 @@ func makeLock(uid UniqueID, t LockType) Lock {
case ReadLock:
value.Readers[uid] = true
case WriteLock:
- value.HasWriter = true
value.Writer = uid
default:
panic(fmt.Sprintf("makeLock: invalid lock type %d", t))
@@ -196,10 +190,7 @@ func makeLock(uid UniqueID, t LockType) Lock {
// isHeld returns true if uid is a holder of Lock.
func (l Lock) isHeld(uid UniqueID) bool {
- if l.HasWriter && l.Writer == uid {
- return true
- }
- return l.Readers[uid]
+ return l.Writer == uid || l.Readers[uid]
}
// lock sets uid as a holder of a typed lock on Lock.
@@ -214,20 +205,20 @@ func (l *Lock) lock(uid UniqueID, t LockType) {
}
// We cannot downgrade a write lock to a read lock unless the
// uid is the same.
- if l.HasWriter {
+ if l.Writer != nil {
if l.Writer != uid {
panic(fmt.Sprintf("lock: cannot downgrade write lock to read lock for uid %d, writer is %d", uid, l.Writer))
}
// Ensure that there is only one reader if upgrading.
l.Readers = make(map[UniqueID]bool)
// Ensure that there is no longer a writer.
- l.HasWriter = false
+ l.Writer = nil
}
l.Readers[uid] = true
return
case WriteLock:
// If we are already the writer, then this is a no-op.
- if l.HasWriter && l.Writer == uid {
+ if l.Writer == uid {
return
}
// We can only upgrade a read lock to a write lock if there
@@ -243,7 +234,6 @@ func (l *Lock) lock(uid UniqueID, t LockType) {
}
// Ensure that there is only a writer.
l.Readers = make(map[UniqueID]bool)
- l.HasWriter = true
l.Writer = uid
default:
panic(fmt.Sprintf("lock: invalid lock type %d", t))
@@ -277,9 +267,8 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool {
switch t {
case ReadLock:
return l.lockable(r, func(value Lock) bool {
- // If there is no writer, there's no problem adding
- // another reader.
- if !value.HasWriter {
+ // If there is no writer, there's no problem adding another reader.
+ if value.Writer == nil {
return true
}
// If there is a writer, then it must be the same uid
@@ -289,10 +278,9 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool {
case WriteLock:
return l.lockable(r, func(value Lock) bool {
// If there are only readers.
- if !value.HasWriter {
- // Then this uid can only take a write lock if
- // this is a private upgrade, meaning that the
- // only reader is uid.
+ if value.Writer == nil {
+ // Then this uid can only take a write lock if this is a private
+ // upgrade, meaning that the only reader is uid.
return len(value.Readers) == 1 && value.Readers[uid]
}
// If the uid is already a writer on this region, then
@@ -304,7 +292,8 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool {
}
}
-// lock returns true if uid took a lock of type t on the entire range of LockRange.
+// lock returns true if uid took a lock of type t on the entire range of
+// LockRange.
//
// Preconditions: r.Start <= r.End (will panic otherwise).
func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool {
@@ -339,7 +328,7 @@ func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool {
seg, _ = l.SplitUnchecked(seg, r.End)
}
- // Set the lock on the segment. This is guaranteed to
+ // Set the lock on the segment. This is guaranteed to
// always be safe, given canLock above.
value := seg.ValuePtr()
value.lock(uid, t)
@@ -386,7 +375,7 @@ func (l *LockSet) unlock(uid UniqueID, r LockRange) {
value := seg.Value()
var remove bool
- if value.HasWriter && value.Writer == uid {
+ if value.Writer == uid {
// If we are unlocking a writer, then since there can
// only ever be one writer and no readers, then this
// lock should always be removed from the set.
diff --git a/pkg/sentry/fs/lock/lock_set_functions.go b/pkg/sentry/fs/lock/lock_set_functions.go
index 8a3ace0c1..50a16e662 100644
--- a/pkg/sentry/fs/lock/lock_set_functions.go
+++ b/pkg/sentry/fs/lock/lock_set_functions.go
@@ -44,14 +44,9 @@ func (lockSetFunctions) Merge(r1 LockRange, val1 Lock, r2 LockRange, val2 Lock)
return Lock{}, false
}
}
- if val1.HasWriter != val2.HasWriter {
+ if val1.Writer != val2.Writer {
return Lock{}, false
}
- if val1.HasWriter {
- if val1.Writer != val2.Writer {
- return Lock{}, false
- }
- }
return val1, true
}
@@ -62,7 +57,6 @@ func (lockSetFunctions) Split(r LockRange, val Lock, split uint64) (Lock, Lock)
for k, v := range val.Readers {
val0.Readers[k] = v
}
- val0.HasWriter = val.HasWriter
val0.Writer = val.Writer
return val, val0
diff --git a/pkg/sentry/fs/lock/lock_test.go b/pkg/sentry/fs/lock/lock_test.go
index ba002aeb7..fad90984b 100644
--- a/pkg/sentry/fs/lock/lock_test.go
+++ b/pkg/sentry/fs/lock/lock_test.go
@@ -42,9 +42,6 @@ func equals(e0, e1 []entry) bool {
if !reflect.DeepEqual(e0[i].LockRange, e1[i].LockRange) {
return false
}
- if e0[i].Lock.HasWriter != e1[i].Lock.HasWriter {
- return false
- }
if e0[i].Lock.Writer != e1[i].Lock.Writer {
return false
}
@@ -105,7 +102,7 @@ func TestCanLock(t *testing.T) {
LockRange: LockRange{2048, 3072},
},
{
- Lock: Lock{HasWriter: true, Writer: 1},
+ Lock: Lock{Writer: 1},
LockRange: LockRange{3072, 4096},
},
})
@@ -241,7 +238,7 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -254,7 +251,7 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -273,7 +270,7 @@ func TestSetLock(t *testing.T) {
LockRange: LockRange{0, 4096},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -301,7 +298,7 @@ func TestSetLock(t *testing.T) {
// 0 4096 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 4096},
},
{
@@ -318,7 +315,7 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -550,7 +547,7 @@ func TestSetLock(t *testing.T) {
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{1024, 4096},
},
{
@@ -594,7 +591,7 @@ func TestSetLock(t *testing.T) {
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{1024, 3072},
},
{
@@ -633,7 +630,7 @@ func TestSetLock(t *testing.T) {
// 0 1024 2048 4096 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
@@ -663,11 +660,11 @@ func TestSetLock(t *testing.T) {
// 0 1024 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{1024, LockEOF},
},
},
@@ -675,28 +672,30 @@ func TestSetLock(t *testing.T) {
}
for _, test := range tests {
- l := fill(test.before)
+ t.Run(test.name, func(t *testing.T) {
+ l := fill(test.before)
- r := LockRange{Start: test.start, End: test.end}
- success := l.lock(test.uid, test.lockType, r)
- var got []entry
- for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
- got = append(got, entry{
- Lock: seg.Value(),
- LockRange: seg.Range(),
- })
- }
+ r := LockRange{Start: test.start, End: test.end}
+ success := l.lock(test.uid, test.lockType, r)
+ var got []entry
+ for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ got = append(got, entry{
+ Lock: seg.Value(),
+ LockRange: seg.Range(),
+ })
+ }
- if success != test.success {
- t.Errorf("%s: setlock(%v, %+v, %d, %d) got success %v, want %v", test.name, test.before, r, test.uid, test.lockType, success, test.success)
- continue
- }
+ if success != test.success {
+ t.Errorf("setlock(%v, %+v, %d, %d) got success %v, want %v", test.before, r, test.uid, test.lockType, success, test.success)
+ return
+ }
- if success {
- if !equals(got, test.after) {
- t.Errorf("%s: got set %+v, want %+v", test.name, got, test.after)
+ if success {
+ if !equals(got, test.after) {
+ t.Errorf("got set %+v, want %+v", got, test.after)
+ }
}
- }
+ })
}
}
@@ -782,7 +781,7 @@ func TestUnlock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -824,7 +823,7 @@ func TestUnlock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -837,7 +836,7 @@ func TestUnlock(t *testing.T) {
// 0 4096 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -876,7 +875,7 @@ func TestUnlock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -889,7 +888,7 @@ func TestUnlock(t *testing.T) {
// 0 4096
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 4096},
},
},
@@ -906,7 +905,7 @@ func TestUnlock(t *testing.T) {
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{1024, 4096},
},
{
@@ -974,7 +973,7 @@ func TestUnlock(t *testing.T) {
// 0 1024 4096 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
@@ -991,7 +990,7 @@ func TestUnlock(t *testing.T) {
// 0 8 4096 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 8},
},
{
@@ -1008,7 +1007,7 @@ func TestUnlock(t *testing.T) {
// 0 1024 4096 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
@@ -1025,7 +1024,7 @@ func TestUnlock(t *testing.T) {
// 0 1024 4096 8192 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
@@ -1041,19 +1040,21 @@ func TestUnlock(t *testing.T) {
}
for _, test := range tests {
- l := fill(test.before)
+ t.Run(test.name, func(t *testing.T) {
+ l := fill(test.before)
- r := LockRange{Start: test.start, End: test.end}
- l.unlock(test.uid, r)
- var got []entry
- for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
- got = append(got, entry{
- Lock: seg.Value(),
- LockRange: seg.Range(),
- })
- }
- if !equals(got, test.after) {
- t.Errorf("%s: got set %+v, want %+v", test.name, got, test.after)
- }
+ r := LockRange{Start: test.start, End: test.end}
+ l.unlock(test.uid, r)
+ var got []entry
+ for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ got = append(got, entry{
+ Lock: seg.Value(),
+ LockRange: seg.Range(),
+ })
+ }
+ if !equals(got, test.after) {
+ t.Errorf("got set %+v, want %+v", got, test.after)
+ }
+ })
}
}