diff options
Diffstat (limited to 'pkg')
144 files changed, 3661 insertions, 2369 deletions
diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go index d1ca56370..b84d7c048 100644 --- a/pkg/abi/linux/fcntl.go +++ b/pkg/abi/linux/fcntl.go @@ -21,6 +21,7 @@ const ( F_SETFD = 2 F_GETFL = 3 F_SETFL = 4 + F_GETLK = 5 F_SETLK = 6 F_SETLKW = 7 F_SETOWN = 8 @@ -55,7 +56,7 @@ type Flock struct { _ [4]byte Start int64 Len int64 - Pid int32 + PID int32 _ [4]byte } diff --git a/pkg/coverage/coverage.go b/pkg/coverage/coverage.go index fdfe31417..6f3d72e83 100644 --- a/pkg/coverage/coverage.go +++ b/pkg/coverage/coverage.go @@ -26,7 +26,6 @@ import ( "fmt" "io" "sort" - "sync/atomic" "testing" "gvisor.dev/gvisor/pkg/sync" @@ -69,12 +68,18 @@ var globalData struct { } // ClearCoverageData clears existing coverage data. +// +//go:norace func ClearCoverageData() { coverageMu.Lock() defer coverageMu.Unlock() + + // We do not use atomic operations while reading/writing to the counters, + // which would drastically degrade performance. Slight discrepancies due to + // racing is okay for the purposes of kcov. for _, counters := range coverdata.Cover.Counters { for index := 0; index < len(counters); index++ { - atomic.StoreUint32(&counters[index], 0) + counters[index] = 0 } } } @@ -114,6 +119,8 @@ var coveragePool = sync.Pool{ // ensure that each event is only reported once. Due to the limitations of Go // coverage tools, we reset the global coverage data every time this function is // run. +// +//go:norace func ConsumeCoverageData(w io.Writer) int { InitCoverageData() @@ -125,11 +132,14 @@ func ConsumeCoverageData(w io.Writer) int { for fileNum, file := range globalData.files { counters := coverdata.Cover.Counters[file] for index := 0; index < len(counters); index++ { - if atomic.LoadUint32(&counters[index]) == 0 { + // We do not use atomic operations while reading/writing to the counters, + // which would drastically degrade performance. Slight discrepancies due to + // racing is okay for the purposes of kcov. + if counters[index] == 0 { continue } // Non-zero coverage data found; consume it and report as a PC. - atomic.StoreUint32(&counters[index], 0) + counters[index] = 0 pc := globalData.syntheticPCs[fileNum][index] usermem.ByteOrder.PutUint64(pcBuffer[:], pc) n, err := w.Write(pcBuffer[:]) diff --git a/pkg/sentry/fs/fdpipe/pipe_state.go b/pkg/sentry/fs/fdpipe/pipe_state.go index af8230a7d..387f713aa 100644 --- a/pkg/sentry/fs/fdpipe/pipe_state.go +++ b/pkg/sentry/fs/fdpipe/pipe_state.go @@ -34,7 +34,9 @@ func (p *pipeOperations) beforeSave() { } else if p.flags.Write { file, err := p.opener.NonBlockingOpen(context.Background(), fs.PermMask{Write: true}) if err != nil { - panic(fs.ErrSaveRejection{fmt.Errorf("write-only pipe end cannot be re-opened as %v: %v", p, err)}) + panic(&fs.ErrSaveRejection{ + Err: fmt.Errorf("write-only pipe end cannot be re-opened as %#v: %w", p, err), + }) } file.Close() } diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go index a020da53b..44587bb37 100644 --- a/pkg/sentry/fs/fs.go +++ b/pkg/sentry/fs/fs.go @@ -144,7 +144,7 @@ type ErrSaveRejection struct { } // Error returns a sensible description of the save rejection error. -func (e ErrSaveRejection) Error() string { +func (e *ErrSaveRejection) Error() string { return "save rejected due to unsupported file system state: " + e.Err.Error() } diff --git a/pkg/sentry/fs/gofer/inode_state.go b/pkg/sentry/fs/gofer/inode_state.go index a3402e343..141e3c27f 100644 --- a/pkg/sentry/fs/gofer/inode_state.go +++ b/pkg/sentry/fs/gofer/inode_state.go @@ -67,7 +67,9 @@ func (i *inodeFileState) beforeSave() { if i.sattr.Type == fs.RegularFile { uattr, err := i.unstableAttr(&dummyClockContext{context.Background()}) if err != nil { - panic(fs.ErrSaveRejection{fmt.Errorf("failed to get unstable atttribute of %s: %v", i.s.inodeMappings[i.sattr.InodeID], err)}) + panic(&fs.ErrSaveRejection{ + Err: fmt.Errorf("failed to get unstable atttribute of %s: %w", i.s.inodeMappings[i.sattr.InodeID], err), + }) } i.savedUAttr = &uattr } diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD index ae3331737..4d3b216d8 100644 --- a/pkg/sentry/fs/lock/BUILD +++ b/pkg/sentry/fs/lock/BUILD @@ -41,6 +41,8 @@ go_library( ], visibility = ["//pkg/sentry:internal"], deps = [ + "//pkg/abi/linux", + "//pkg/context", "//pkg/log", "//pkg/sync", "//pkg/waiter", diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go index 8a5d9c7eb..57686ce07 100644 --- a/pkg/sentry/fs/lock/lock.go +++ b/pkg/sentry/fs/lock/lock.go @@ -54,6 +54,8 @@ import ( "math" "syscall" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) @@ -83,6 +85,17 @@ const ( // offset 0 to LockEOF. const LockEOF = math.MaxUint64 +// OwnerInfo describes the owner of a lock. +// +// TODO(gvisor.dev/issue/5264): We may need to add other fields in the future +// (e.g., Linux's file_lock.fl_flags to support open file-descriptor locks). +// +// +stateify savable +type OwnerInfo struct { + // PID is the process ID of the lock owner. + PID int32 +} + // Lock is a regional file lock. It consists of either a single writer // or a set of readers. // @@ -92,14 +105,20 @@ const LockEOF = math.MaxUint64 // A Lock may be downgraded from a write lock to a read lock only if // the write lock's uid is the same as the read lock. // +// Accesses to Lock are synchronized through the Locks object to which it +// belongs. +// // +stateify savable type Lock struct { // Readers are the set of read lock holders identified by UniqueID. - // If len(Readers) > 0 then HasWriter must be false. - Readers map[UniqueID]bool + // If len(Readers) > 0 then Writer must be nil. + Readers map[UniqueID]OwnerInfo // Writer holds the writer unique ID. It's nil if there are no writers. Writer UniqueID + + // WriterInfo describes the writer. It is only meaningful if Writer != nil. + WriterInfo OwnerInfo } // Locks is a thread-safe wrapper around a LockSet. @@ -135,14 +154,14 @@ const ( // acquiring the lock in a non-blocking mode or "interrupted" if in a blocking mode. // Blocker is the interface used to provide blocking behavior, passing a nil Blocker // will result in non-blocking behavior. -func (l *Locks) LockRegion(uid UniqueID, t LockType, r LockRange, block Blocker) bool { +func (l *Locks) LockRegion(uid UniqueID, ownerPID int32, t LockType, r LockRange, block Blocker) bool { for { l.mu.Lock() // Blocking locks must run in a loop because we'll be woken up whenever an unlock event // happens for this lock. We will then attempt to take the lock again and if it fails // continue blocking. - res := l.locks.lock(uid, t, r) + res := l.locks.lock(uid, ownerPID, t, r) if !res && block != nil { e, ch := waiter.NewChannelEntry(nil) l.blockedQueue.EventRegister(&e, EventMaskAll) @@ -161,6 +180,14 @@ func (l *Locks) LockRegion(uid UniqueID, t LockType, r LockRange, block Blocker) } } +// LockRegionVFS1 is a wrapper around LockRegion for VFS1, which does not implement +// F_GETLK (and does not care about storing PIDs as a result). +// +// TODO(gvisor.dev/issue/1624): Delete. +func (l *Locks) LockRegionVFS1(uid UniqueID, t LockType, r LockRange, block Blocker) bool { + return l.LockRegion(uid, 0 /* ownerPID */, t, r, block) +} + // UnlockRegion attempts to release a lock for the uid on a region of a file. // This operation is always successful, even if there did not exist a lock on // the requested region held by uid in the first place. @@ -175,13 +202,14 @@ func (l *Locks) UnlockRegion(uid UniqueID, r LockRange) { // makeLock returns a new typed Lock that has either uid as its only reader // or uid as its only writer. -func makeLock(uid UniqueID, t LockType) Lock { - value := Lock{Readers: make(map[UniqueID]bool)} +func makeLock(uid UniqueID, ownerPID int32, t LockType) Lock { + value := Lock{Readers: make(map[UniqueID]OwnerInfo)} switch t { case ReadLock: - value.Readers[uid] = true + value.Readers[uid] = OwnerInfo{PID: ownerPID} case WriteLock: value.Writer = uid + value.WriterInfo = OwnerInfo{PID: ownerPID} default: panic(fmt.Sprintf("makeLock: invalid lock type %d", t)) } @@ -190,17 +218,20 @@ func makeLock(uid UniqueID, t LockType) Lock { // isHeld returns true if uid is a holder of Lock. func (l Lock) isHeld(uid UniqueID) bool { - return l.Writer == uid || l.Readers[uid] + if _, ok := l.Readers[uid]; ok { + return true + } + return l.Writer == uid } // lock sets uid as a holder of a typed lock on Lock. // // Preconditions: canLock is true for the range containing this Lock. -func (l *Lock) lock(uid UniqueID, t LockType) { +func (l *Lock) lock(uid UniqueID, ownerPID int32, t LockType) { switch t { case ReadLock: // If we are already a reader, then this is a no-op. - if l.Readers[uid] { + if _, ok := l.Readers[uid]; ok { return } // We cannot downgrade a write lock to a read lock unless the @@ -210,11 +241,11 @@ func (l *Lock) lock(uid UniqueID, t LockType) { panic(fmt.Sprintf("lock: cannot downgrade write lock to read lock for uid %d, writer is %d", uid, l.Writer)) } // Ensure that there is only one reader if upgrading. - l.Readers = make(map[UniqueID]bool) + l.Readers = make(map[UniqueID]OwnerInfo) // Ensure that there is no longer a writer. l.Writer = nil } - l.Readers[uid] = true + l.Readers[uid] = OwnerInfo{PID: ownerPID} return case WriteLock: // If we are already the writer, then this is a no-op. @@ -228,13 +259,14 @@ func (l *Lock) lock(uid UniqueID, t LockType) { if readers != 1 { panic(fmt.Sprintf("lock: cannot upgrade read lock to write lock for uid %d, too many readers %v", uid, l.Readers)) } - if !l.Readers[uid] { + if _, ok := l.Readers[uid]; !ok { panic(fmt.Sprintf("lock: cannot upgrade read lock to write lock for uid %d, conflicting reader %v", uid, l.Readers)) } } // Ensure that there is only a writer. - l.Readers = make(map[UniqueID]bool) + l.Readers = make(map[UniqueID]OwnerInfo) l.Writer = uid + l.WriterInfo = OwnerInfo{PID: ownerPID} default: panic(fmt.Sprintf("lock: invalid lock type %d", t)) } @@ -247,7 +279,7 @@ func (l LockSet) lockable(r LockRange, check func(value Lock) bool) bool { // Get our starting point. seg := l.LowerBoundSegment(r.Start) for seg.Ok() && seg.Start() < r.End { - // Note that we don't care about overruning the end of the + // Note that we don't care about overrunning the end of the // last segment because if everything checks out we'll just // split the last segment. if !check(seg.Value()) { @@ -281,7 +313,7 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool { if value.Writer == nil { // Then this uid can only take a write lock if this is a private // upgrade, meaning that the only reader is uid. - return len(value.Readers) == 1 && value.Readers[uid] + return value.isOnlyReader(uid) } // If the uid is already a writer on this region, then // adding a write lock would be a no-op. @@ -292,11 +324,19 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool { } } +func (l *Lock) isOnlyReader(uid UniqueID) bool { + if len(l.Readers) != 1 { + return false + } + _, ok := l.Readers[uid] + return ok +} + // lock returns true if uid took a lock of type t on the entire range of // LockRange. // // Preconditions: r.Start <= r.End (will panic otherwise). -func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool { +func (l *LockSet) lock(uid UniqueID, ownerPID int32, t LockType, r LockRange) bool { if r.Start > r.End { panic(fmt.Sprintf("lock: r.Start %d > r.End %d", r.Start, r.End)) } @@ -317,7 +357,7 @@ func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool { seg, gap := l.Find(r.Start) if gap.Ok() { // Fill in the gap and get the next segment to modify. - seg = l.Insert(gap, gap.Range().Intersect(r), makeLock(uid, t)).NextSegment() + seg = l.Insert(gap, gap.Range().Intersect(r), makeLock(uid, ownerPID, t)).NextSegment() } else if seg.Start() < r.Start { // Get our first segment to modify. _, seg = l.Split(seg, r.Start) @@ -331,12 +371,12 @@ func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool { // Set the lock on the segment. This is guaranteed to // always be safe, given canLock above. value := seg.ValuePtr() - value.lock(uid, t) + value.lock(uid, ownerPID, t) // Fill subsequent gaps. gap = seg.NextGap() if gr := gap.Range().Intersect(r); gr.Length() > 0 { - seg = l.Insert(gap, gr, makeLock(uid, t)).NextSegment() + seg = l.Insert(gap, gr, makeLock(uid, ownerPID, t)).NextSegment() } else { seg = gap.NextSegment() } @@ -380,7 +420,7 @@ func (l *LockSet) unlock(uid UniqueID, r LockRange) { // only ever be one writer and no readers, then this // lock should always be removed from the set. remove = true - } else if value.Readers[uid] { + } else if _, ok := value.Readers[uid]; ok { // If uid is the last reader, then just remove the entire // segment. if len(value.Readers) == 1 { @@ -390,7 +430,7 @@ func (l *LockSet) unlock(uid UniqueID, r LockRange) { // affecting any other segment's readers. To do // this, we need to make a copy of the Readers map // and not add this uid. - newValue := Lock{Readers: make(map[UniqueID]bool)} + newValue := Lock{Readers: make(map[UniqueID]OwnerInfo)} for k, v := range value.Readers { if k != uid { newValue.Readers[k] = v @@ -451,3 +491,72 @@ func ComputeRange(start, length, offset int64) (LockRange, error) { // Offset is guaranteed to be positive at this point. return LockRange{Start: uint64(offset), End: end}, nil } + +// TestRegion checks whether the lock holder identified by uid can hold a lock +// of type t on range r. It returns a Flock struct representing this +// information as the F_GETLK fcntl does. +// +// Note that the PID returned in the flock structure is relative to the root PID +// namespace. It needs to be converted to the caller's PID namespace before +// returning to userspace. +// +// TODO(gvisor.dev/issue/5264): we don't support OFD locks through fcntl, which +// would return a struct with pid = -1. +func (l *Locks) TestRegion(ctx context.Context, uid UniqueID, t LockType, r LockRange) linux.Flock { + f := linux.Flock{Type: linux.F_UNLCK} + switch t { + case ReadLock: + l.testRegion(r, func(lock Lock, start, length uint64) bool { + if lock.Writer == nil || lock.Writer == uid { + return true + } + f.Type = linux.F_WRLCK + f.PID = lock.WriterInfo.PID + f.Start = int64(start) + f.Len = int64(length) + return false + }) + case WriteLock: + l.testRegion(r, func(lock Lock, start, length uint64) bool { + if lock.Writer == nil { + for k, v := range lock.Readers { + if k != uid { + // Stop at the first conflict detected. + f.Type = linux.F_RDLCK + f.PID = v.PID + f.Start = int64(start) + f.Len = int64(length) + return false + } + } + return true + } + if lock.Writer == uid { + return true + } + f.Type = linux.F_WRLCK + f.PID = lock.WriterInfo.PID + f.Start = int64(start) + f.Len = int64(length) + return false + }) + default: + panic(fmt.Sprintf("TestRegion: invalid lock type %d", t)) + } + return f +} + +func (l *Locks) testRegion(r LockRange, check func(lock Lock, start, length uint64) bool) { + l.mu.Lock() + defer l.mu.Unlock() + + seg := l.locks.LowerBoundSegment(r.Start) + for seg.Ok() && seg.Start() < r.End { + lock := seg.Value() + if !check(lock, seg.Start(), seg.End()-seg.Start()) { + // Stop at the first conflict detected. + return + } + seg = seg.NextSegment() + } +} diff --git a/pkg/sentry/fs/lock/lock_set_functions.go b/pkg/sentry/fs/lock/lock_set_functions.go index 50a16e662..dcc17c0dc 100644 --- a/pkg/sentry/fs/lock/lock_set_functions.go +++ b/pkg/sentry/fs/lock/lock_set_functions.go @@ -40,7 +40,7 @@ func (lockSetFunctions) Merge(r1 LockRange, val1 Lock, r2 LockRange, val2 Lock) return Lock{}, false } for k := range val1.Readers { - if !val2.Readers[k] { + if _, ok := val2.Readers[k]; !ok { return Lock{}, false } } @@ -53,11 +53,12 @@ func (lockSetFunctions) Merge(r1 LockRange, val1 Lock, r2 LockRange, val2 Lock) func (lockSetFunctions) Split(r LockRange, val Lock, split uint64) (Lock, Lock) { // Copy the segment so that split segments don't contain map references // to other segments. - val0 := Lock{Readers: make(map[UniqueID]bool)} + val0 := Lock{Readers: make(map[UniqueID]OwnerInfo)} for k, v := range val.Readers { val0.Readers[k] = v } val0.Writer = val.Writer + val0.WriterInfo = val.WriterInfo return val, val0 } diff --git a/pkg/sentry/fs/lock/lock_test.go b/pkg/sentry/fs/lock/lock_test.go index fad90984b..9878c04e1 100644 --- a/pkg/sentry/fs/lock/lock_test.go +++ b/pkg/sentry/fs/lock/lock_test.go @@ -30,12 +30,12 @@ func equals(e0, e1 []entry) bool { } for i := range e0 { for k := range e0[i].Lock.Readers { - if !e1[i].Lock.Readers[k] { + if _, ok := e1[i].Lock.Readers[k]; !ok { return false } } for k := range e1[i].Lock.Readers { - if !e0[i].Lock.Readers[k] { + if _, ok := e0[i].Lock.Readers[k]; !ok { return false } } @@ -90,15 +90,15 @@ func TestCanLock(t *testing.T) { // 0 1024 2048 3072 4096 l := fill([]entry{ { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}}}, LockRange: LockRange{1024, 2048}, }, { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 3: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 3: OwnerInfo{}}}, LockRange: LockRange{2048, 3072}, }, { @@ -220,7 +220,7 @@ func TestSetLock(t *testing.T) { // 0 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -266,7 +266,7 @@ func TestSetLock(t *testing.T) { // 0 4096 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, 4096}, }, { @@ -283,7 +283,7 @@ func TestSetLock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -302,7 +302,7 @@ func TestSetLock(t *testing.T) { LockRange: LockRange{0, 4096}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -333,7 +333,7 @@ func TestSetLock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -351,7 +351,7 @@ func TestSetLock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -366,11 +366,11 @@ func TestSetLock(t *testing.T) { // 0 4096 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{0, 4096}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -383,7 +383,7 @@ func TestSetLock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -398,15 +398,15 @@ func TestSetLock(t *testing.T) { // 0 4096 8192 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, 4096}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{4096, 8192}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{8192, LockEOF}, }, }, @@ -419,7 +419,7 @@ func TestSetLock(t *testing.T) { // 0 1024 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{1024, LockEOF}, }, }, @@ -434,7 +434,7 @@ func TestSetLock(t *testing.T) { // 0 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -447,7 +447,7 @@ func TestSetLock(t *testing.T) { // 0 4096 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, 4096}, }, }, @@ -467,11 +467,11 @@ func TestSetLock(t *testing.T) { // 0 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{1024, LockEOF}, }, }, @@ -484,7 +484,7 @@ func TestSetLock(t *testing.T) { // 0 1024 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{1024, LockEOF}, }, }, @@ -499,15 +499,15 @@ func TestSetLock(t *testing.T) { // 0 1024 4096 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{1024, 4096}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -520,15 +520,15 @@ func TestSetLock(t *testing.T) { // 0 1024 2048 4096 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{1024, 2048}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -543,7 +543,7 @@ func TestSetLock(t *testing.T) { // 0 1024 4096 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { @@ -551,7 +551,7 @@ func TestSetLock(t *testing.T) { LockRange: LockRange{1024, 4096}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -564,15 +564,15 @@ func TestSetLock(t *testing.T) { // 0 1024 2048 4096 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{1024, 2048}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -587,7 +587,7 @@ func TestSetLock(t *testing.T) { // 0 1024 3072 4096 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { @@ -595,7 +595,7 @@ func TestSetLock(t *testing.T) { LockRange: LockRange{1024, 3072}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -608,11 +608,11 @@ func TestSetLock(t *testing.T) { // 0 1024 2048 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{1024, 2048}, }, }, @@ -634,15 +634,15 @@ func TestSetLock(t *testing.T) { LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{1024, 2048}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{2048, 4096}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -676,7 +676,7 @@ func TestSetLock(t *testing.T) { l := fill(test.before) r := LockRange{Start: test.start, End: test.end} - success := l.lock(test.uid, test.lockType, r) + success := l.lock(test.uid, 0 /* ownerPID */, test.lockType, r) var got []entry for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { got = append(got, entry{ @@ -739,7 +739,7 @@ func TestUnlock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -752,7 +752,7 @@ func TestUnlock(t *testing.T) { // 0 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -765,7 +765,7 @@ func TestUnlock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -797,7 +797,7 @@ func TestUnlock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -810,7 +810,7 @@ func TestUnlock(t *testing.T) { // 0 4096 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -849,7 +849,7 @@ func TestUnlock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -862,7 +862,7 @@ func TestUnlock(t *testing.T) { // 0 4096 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, 4096}, }, }, @@ -901,7 +901,7 @@ func TestUnlock(t *testing.T) { // 0 1024 4096 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { @@ -909,7 +909,7 @@ func TestUnlock(t *testing.T) { LockRange: LockRange{1024, 4096}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -922,11 +922,11 @@ func TestUnlock(t *testing.T) { // 0 1024 4096 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -939,7 +939,7 @@ func TestUnlock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -952,15 +952,15 @@ func TestUnlock(t *testing.T) { // 0 1024 4096 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{1024, 4096}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -977,7 +977,7 @@ func TestUnlock(t *testing.T) { LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -994,7 +994,7 @@ func TestUnlock(t *testing.T) { LockRange: LockRange{0, 8}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -1011,7 +1011,7 @@ func TestUnlock(t *testing.T) { LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -1028,11 +1028,11 @@ func TestUnlock(t *testing.T) { LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}}}, LockRange: LockRange{4096, 8192}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{8192, LockEOF}, }, }, diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go index e91fa26a4..b44117f40 100644 --- a/pkg/sentry/fsimpl/devpts/master.go +++ b/pkg/sentry/fsimpl/devpts/master.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -194,16 +193,6 @@ func (mfd *masterFileDescription) Stat(ctx context.Context, opts vfs.StatOptions return mfd.inode.Stat(ctx, fs, opts) } -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (mfd *masterFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return mfd.Locks().LockPOSIX(ctx, &mfd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (mfd *masterFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return mfd.Locks().UnlockPOSIX(ctx, &mfd.vfsfd, uid, start, length, whence) -} - // maybeEmitUnimplementedEvent emits unimplemented event if cmd is valid. func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) { switch cmd { diff --git a/pkg/sentry/fsimpl/devpts/replica.go b/pkg/sentry/fsimpl/devpts/replica.go index 70c68cf0a..a0c5b5af5 100644 --- a/pkg/sentry/fsimpl/devpts/replica.go +++ b/pkg/sentry/fsimpl/devpts/replica.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -189,13 +188,3 @@ func (rfd *replicaFileDescription) Stat(ctx context.Context, opts vfs.StatOption fs := rfd.vfsfd.VirtualDentry().Mount().Filesystem() return rfd.inode.Stat(ctx, fs, opts) } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (rfd *replicaFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return rfd.Locks().LockPOSIX(ctx, &rfd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (rfd *replicaFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return rfd.Locks().UnlockPOSIX(ctx, &rfd.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index 0ad79b381..512b70ede 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -311,13 +310,3 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in fd.off = offset return offset, nil } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *directoryFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *directoryFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go index 4a5539b37..5ad9befcd 100644 --- a/pkg/sentry/fsimpl/ext/regular_file.go +++ b/pkg/sentry/fsimpl/ext/regular_file.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -154,13 +153,3 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt // TODO(b/134676337): Implement mmap(2). return syserror.ENODEV } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *regularFileFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 3b5927702..9da01cba3 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -90,6 +90,7 @@ type createSyntheticOpts struct { // * d.isDir(). // * d does not already contain a child with the given name. func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { + now := d.fs.clock.Now().Nanoseconds() child := &dentry{ refs: 1, // held by d fs: d.fs, @@ -98,6 +99,10 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { uid: uint32(opts.kuid), gid: uint32(opts.kgid), blockSize: usermem.PageSize, // arbitrary + atime: now, + mtime: now, + ctime: now, + btime: now, readFD: -1, writeFD: -1, mmapFD: -1, diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 91d5dc174..8f95473b6 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -36,16 +36,26 @@ import ( // Sync implements vfs.FilesystemImpl.Sync. func (fs *filesystem) Sync(ctx context.Context) error { // Snapshot current syncable dentries and special file FDs. + fs.renameMu.RLock() fs.syncMu.Lock() ds := make([]*dentry, 0, len(fs.syncableDentries)) for d := range fs.syncableDentries { + // It's safe to use IncRef here even though fs.syncableDentries doesn't + // hold references since we hold fs.renameMu. Note that we can't use + // TryIncRef since cached dentries at zero references should still be + // synced. d.IncRef() ds = append(ds, d) } + fs.renameMu.RUnlock() sffds := make([]*specialFileFD, 0, len(fs.specialFileFDs)) for sffd := range fs.specialFileFDs { - sffd.vfsfd.IncRef() - sffds = append(sffds, sffd) + // As above, fs.specialFileFDs doesn't hold references. However, unlike + // dentries, an FD that has reached zero references can't be + // resurrected, so we can use TryIncRef. + if sffd.vfsfd.TryIncRef() { + sffds = append(sffds, sffd) + } } fs.syncMu.Unlock() diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 3cdb1e659..98f7bc52f 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -1944,22 +1944,22 @@ func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error { } // LockBSD implements vfs.FileDescriptionImpl.LockBSD. -func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error { +func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { fd.lockLogging.Do(func() { log.Infof("File lock using gofer file handled internally.") }) - return fd.LockFD.LockBSD(ctx, uid, t, block) + return fd.LockFD.LockBSD(ctx, uid, ownerPID, t, block) } // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { +func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error { fd.lockLogging.Do(func() { log.Infof("Range lock using gofer file handled internally.") }) - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) + return fd.Locks().LockPOSIX(ctx, uid, ownerPID, t, r, block) } // UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error { + return fd.Locks().UnlockPOSIX(ctx, uid, r) } diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 36a3f6810..05f11fbd5 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -28,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/log" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/hostfd" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -810,13 +809,3 @@ func (f *fileDescription) EventUnregister(e *waiter.Entry) { func (f *fileDescription) Readiness(mask waiter.EventMask) waiter.EventMask { return fdnotifier.NonBlockingPoll(int32(f.inode.hostFD), mask) } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (f *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return f.Locks().LockPOSIX(ctx, &f.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (f *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return f.Locks().UnlockPOSIX(ctx, &f.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go index f5c596fec..0f9e20a84 100644 --- a/pkg/sentry/fsimpl/host/tty.go +++ b/pkg/sentry/fsimpl/host/tty.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -370,13 +369,3 @@ func (t *TTYFileDescription) checkChange(ctx context.Context, sig linux.Signal) _ = pg.SendSignal(kernel.SignalInfoPriv(sig)) return syserror.ERESTARTSYS } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (t *TTYFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, typ fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return t.Locks().LockPOSIX(ctx, &t.vfsfd, uid, typ, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (t *TTYFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return t.Locks().UnlockPOSIX(ctx, &t.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index 485504995..65054b0ea 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" @@ -136,13 +135,3 @@ func (fd *DynamicBytesFD) SetStat(context.Context, vfs.SetStatOptions) error { // DynamicBytesFiles are immutable. return syserror.EPERM } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *DynamicBytesFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *DynamicBytesFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index f8dae22f8..e55111af0 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -275,13 +274,3 @@ func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptio func (fd *GenericDirectoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error { return fd.DirectoryFileDescriptionDefaultImpl.Allocate(ctx, mode, offset, length) } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *GenericDirectoryFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *GenericDirectoryFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index 3492409b2..082fa6504 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -42,7 +42,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/refsvfs2" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -820,13 +819,3 @@ func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error { d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 75be6129f..fdae163d1 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -518,16 +517,6 @@ func (fd *memFD) SetStat(context.Context, vfs.SetStatOptions) error { // Release implements vfs.FileDescriptionImpl.Release. func (fd *memFD) Release(context.Context) {} -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *memFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *memFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) -} - // mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps. // // +stateify savable @@ -1110,13 +1099,3 @@ func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) err func (fd *namespaceFD) Release(ctx context.Context) { fd.inode.DecRef(ctx) } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *namespaceFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *namespaceFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index 7ee6227a9..d6f076cd6 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -393,7 +393,7 @@ func TestProcSelf(t *testing.T) { t.Fatalf("CreateTask(): %v", err) } - collector := s.WithTemporaryContext(task).ListDirents(&vfs.PathOperation{ + collector := s.WithTemporaryContext(task.AsyncContext()).ListDirents(&vfs.PathOperation{ Root: s.Root, Start: s.Root, Path: fspath.Parse("/proc/self/"), @@ -491,11 +491,11 @@ func TestTree(t *testing.T) { t.Fatalf("CreateTask(): %v", err) } // Add file to populate /proc/[pid]/fd and fdinfo directories. - task.FDTable().NewFDVFS2(task, 0, file, kernel.FDFlags{}) + task.FDTable().NewFDVFS2(task.AsyncContext(), 0, file, kernel.FDFlags{}) tasks = append(tasks, task) } - ctx := tasks[0] + ctx := tasks[0].AsyncContext() fd, err := s.VFS.OpenAt( ctx, auth.CredentialsFromContext(s.Ctx), diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go index 146c7fdfe..4393cc13b 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go @@ -140,35 +140,35 @@ func TestLocks(t *testing.T) { uid1 := 123 uid2 := 456 - if err := fd.Impl().LockBSD(ctx, uid1, lock.ReadLock, nil); err != nil { + if err := fd.Impl().LockBSD(ctx, uid1, 0 /* ownerPID */, lock.ReadLock, nil); err != nil { t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) } - if err := fd.Impl().LockBSD(ctx, uid2, lock.ReadLock, nil); err != nil { + if err := fd.Impl().LockBSD(ctx, uid2, 0 /* ownerPID */, lock.ReadLock, nil); err != nil { t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) } - if got, want := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, nil), syserror.ErrWouldBlock; got != want { + if got, want := fd.Impl().LockBSD(ctx, uid2, 0 /* ownerPID */, lock.WriteLock, nil), syserror.ErrWouldBlock; got != want { t.Fatalf("fd.Impl().LockBSD failed: got = %v, want = %v", got, want) } if err := fd.Impl().UnlockBSD(ctx, uid1); err != nil { t.Fatalf("fd.Impl().UnlockBSD failed: err = %v", err) } - if err := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, nil); err != nil { + if err := fd.Impl().LockBSD(ctx, uid2, 0 /* ownerPID */, lock.WriteLock, nil); err != nil { t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) } - if err := fd.Impl().LockPOSIX(ctx, uid1, lock.ReadLock, 0, 1, linux.SEEK_SET, nil); err != nil { + if err := fd.Impl().LockPOSIX(ctx, uid1, 0 /* ownerPID */, lock.ReadLock, lock.LockRange{Start: 0, End: 1}, nil); err != nil { t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) } - if err := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, 1, 2, linux.SEEK_SET, nil); err != nil { + if err := fd.Impl().LockPOSIX(ctx, uid2, 0 /* ownerPID */, lock.ReadLock, lock.LockRange{Start: 1, End: 2}, nil); err != nil { t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) } - if err := fd.Impl().LockPOSIX(ctx, uid1, lock.WriteLock, 0, 1, linux.SEEK_SET, nil); err != nil { + if err := fd.Impl().LockPOSIX(ctx, uid1, 0 /* ownerPID */, lock.WriteLock, lock.LockRange{Start: 0, End: 1}, nil); err != nil { t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) } - if got, want := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, 0, 1, linux.SEEK_SET, nil), syserror.ErrWouldBlock; got != want { + if got, want := fd.Impl().LockPOSIX(ctx, uid2, 0 /* ownerPID */, lock.ReadLock, lock.LockRange{Start: 0, End: 1}, nil), syserror.ErrWouldBlock; got != want { t.Fatalf("fd.Impl().LockPOSIX failed: got = %v, want = %v", got, want) } - if err := fd.Impl().UnlockPOSIX(ctx, uid1, 0, 1, linux.SEEK_SET); err != nil { + if err := fd.Impl().UnlockPOSIX(ctx, uid1, lock.LockRange{Start: 0, End: 1}); err != nil { t.Fatalf("fd.Impl().UnlockPOSIX failed: err = %v", err) } } diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 0c9c639d3..b32c54e20 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -36,7 +36,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/pgalloc" @@ -797,16 +796,6 @@ func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error { return nil } -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) -} - // Sync implements vfs.FileDescriptionImpl.Sync. It does nothing because all // filesystem state is in-memory. func (*fileDescription) Sync(context.Context) error { diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index a5171b5ad..8645078a0 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -660,7 +660,6 @@ func (d *dentry) readlink(ctx context.Context) (string, error) { type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl - vfs.LockFD // d is the corresponding dentry to the fileDescription. d *dentry @@ -1104,14 +1103,29 @@ func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, op return 0, syserror.EROFS } +// LockBSD implements vfs.FileDescriptionImpl.LockBSD. +func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { + return fd.lowerFD.LockBSD(ctx, ownerPID, t, block) +} + +// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD. +func (fd *fileDescription) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { + return fd.lowerFD.UnlockBSD(ctx) +} + // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.lowerFD.LockPOSIX(ctx, uid, t, start, length, whence, block) +func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error { + return fd.lowerFD.LockPOSIX(ctx, uid, ownerPID, t, r, block) } // UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.lowerFD.UnlockPOSIX(ctx, uid, start, length, whence) +func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error { + return fd.lowerFD.UnlockPOSIX(ctx, uid, r) +} + +// TestPOSIX implements vfs.FileDescriptionImpl.TestPOSIX. +func (fd *fileDescription) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) { + return fd.lowerFD.TestPOSIX(ctx, uid, t, r) } // FileReadWriteSeeker is a helper struct to pass a vfs.FileDescription as diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index 30d8b4355..798d6a9bd 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -66,7 +66,7 @@ func dentryFromFD(t *testing.T, fd *vfs.FileDescription) *dentry { // newVerityRoot creates a new verity mount, and returns the root. The // underlying file system is tmpfs. If the error is not nil, then cleanup // should be called when the root is no longer needed. -func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) { +func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, context.Context, error) { t.Helper() k, err := testutil.Boot() if err != nil { @@ -119,7 +119,7 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, root.DecRef(ctx) mntns.DecRef(ctx) }) - return vfsObj, root, task, nil + return vfsObj, root, task.AsyncContext(), nil } // openVerityAt opens a verity file. diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 0ee60569c..8a5b11d40 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -240,7 +240,6 @@ go_library( "//pkg/sentry/fs/lock", "//pkg/sentry/fs/timerfd", "//pkg/sentry/fsbridge", - "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/fsimpl/pipefs", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/fsimpl/timerfd", diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 7aba31587..a6afabb1c 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -153,20 +153,12 @@ func (f *FDTable) drop(ctx context.Context, file *fs.File) { // dropVFS2 drops the table reference. func (f *FDTable) dropVFS2(ctx context.Context, file *vfs.FileDescription) { - // Release any POSIX lock possibly held by the FDTable. Range {0, 0} means the - // entire file. - err := file.UnlockPOSIX(ctx, f, 0, 0, linux.SEEK_SET) + // Release any POSIX lock possibly held by the FDTable. + err := file.UnlockPOSIX(ctx, f, lock.LockRange{0, lock.LockEOF}) if err != nil && err != syserror.ENOLCK { panic(fmt.Sprintf("UnlockPOSIX failed: %v", err)) } - // Generate inotify events. - ev := uint32(linux.IN_CLOSE_NOWRITE) - if file.IsWritable() { - ev = linux.IN_CLOSE_WRITE - } - file.Dentry().InotifyWithParent(ctx, ev, 0, vfs.PathEvent) - // Drop the table's reference. file.DecRef(ctx) } diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 303ae8056..ef4e934a1 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -593,8 +593,8 @@ func (k *Kernel) flushWritesToFiles(ctx context.Context) error { // Wrap this error in ErrSaveRejection so that it will trigger a save // error, rather than a panic. This also allows us to distinguish Fsync // errors from state file errors in state.Save. - return fs.ErrSaveRejection{ - Err: fmt.Errorf("%q was not sufficiently synced: %v", name, err), + return &fs.ErrSaveRejection{ + Err: fmt.Errorf("%q was not sufficiently synced: %w", name, err), } } return nil diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 2c32d017d..71daa9f4b 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -27,7 +27,6 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/lock", "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserror", diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index d5a91730d..3b6336e94 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -441,13 +440,3 @@ func spliceOrTee(ctx context.Context, dst, src *VFSPipeFD, count int64, removeFr } return n, err } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *VFSPipeFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *VFSPipeFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go index 9419f2e95..ecbe8f920 100644 --- a/pkg/sentry/kernel/task_block.go +++ b/pkg/sentry/kernel/task_block.go @@ -69,7 +69,7 @@ func (t *Task) BlockWithTimeout(C chan struct{}, haveTimeout bool, timeout time. // syserror.ErrInterrupted if t is interrupted. // // Preconditions: The caller must be running on the task goroutine. -func (t *Task) BlockWithDeadline(C chan struct{}, haveDeadline bool, deadline ktime.Time) error { +func (t *Task) BlockWithDeadline(C <-chan struct{}, haveDeadline bool, deadline ktime.Time) error { if !haveDeadline { return t.block(C, nil) } diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go index fdadb52c0..e9da99067 100644 --- a/pkg/sentry/kernel/threads.go +++ b/pkg/sentry/kernel/threads.go @@ -216,7 +216,7 @@ func (ns *PIDNamespace) TaskWithID(tid ThreadID) *Task { return t } -// ThreadGroupWithID returns the thread group lead by the task with thread ID +// ThreadGroupWithID returns the thread group led by the task with thread ID // tid in PID namespace ns. If no task has that TID, or if the task with that // TID is not a thread group leader, ThreadGroupWithID returns nil. func (ns *PIDNamespace) ThreadGroupWithID(tid ThreadID) *ThreadGroup { @@ -292,6 +292,11 @@ func (ns *PIDNamespace) UserNamespace() *auth.UserNamespace { return ns.userns } +// Root returns the root PID namespace of ns. +func (ns *PIDNamespace) Root() *PIDNamespace { + return ns.owner.Root +} + // A threadGroupNode defines the relationship between a thread group and the // rest of the system. Conceptually, threadGroupNode is data belonging to the // owning TaskSet, as if TaskSet contained a field `nodes @@ -485,3 +490,8 @@ func (t *Task) Parent() *Task { func (t *Task) ThreadID() ThreadID { return t.tg.pidns.IDOfTask(t) } + +// TGIDInRoot returns t's TGID in the root PID namespace. +func (t *Task) TGIDInRoot() ThreadID { + return t.tg.pidns.owner.Root.IDOfThreadGroup(t.tg) +} diff --git a/pkg/sentry/mm/procfs.go b/pkg/sentry/mm/procfs.go index 6efe5102b..73bfbea49 100644 --- a/pkg/sentry/mm/procfs.go +++ b/pkg/sentry/mm/procfs.go @@ -17,7 +17,6 @@ package mm import ( "bytes" "fmt" - "strings" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile" @@ -165,12 +164,12 @@ func (mm *MemoryManager) appendVMAMapsEntryLocked(ctx context.Context, vseg vmaI } if s != "" { // Per linux, we pad until the 74th character. - if pad := 73 - lineLen; pad > 0 { - b.WriteString(strings.Repeat(" ", pad)) + for pad := 73 - lineLen; pad > 0; pad-- { + b.WriteByte(' ') } b.WriteString(s) } - b.WriteString("\n") + b.WriteByte('\n') } // ReadSmapsDataInto is called by fsimpl/proc.smapsData.Generate to diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go index 675efdc7c..69e37330b 100644 --- a/pkg/sentry/mm/syscalls.go +++ b/pkg/sentry/mm/syscalls.go @@ -1055,18 +1055,11 @@ func (mm *MemoryManager) Decommit(addr usermem.Addr, length uint64) error { mm.activeMu.Lock() defer mm.activeMu.Unlock() - // Linux's mm/madvise.c:madvise_dontneed() => mm/memory.c:zap_page_range() - // is analogous to our mm.invalidateLocked(ar, true, true). We inline this - // here, with the special case that we synchronously decommit - // uniquely-owned (non-copy-on-write) pages for private anonymous vma, - // which is the common case for MADV_DONTNEED. Invalidating these pmas, and - // allowing them to be reallocated when touched again, increases pma - // fragmentation, which may significantly reduce performance for - // non-vectored I/O implementations. Also, decommitting synchronously - // ensures that Decommit immediately reduces host memory usage. + // This is invalidateLocked(invalidatePrivate=true, invalidateShared=true), + // with the additional wrinkle that we must refuse to invalidate pmas under + // mlocked vmas. var didUnmapAS bool pseg := mm.pmas.LowerBoundSegment(ar.Start) - mf := mm.mfp.MemoryFile() for vseg := mm.vmas.LowerBoundSegment(ar.Start); vseg.Ok() && vseg.Start() < ar.End; vseg = vseg.NextSegment() { vma := vseg.ValuePtr() if vma.mlockMode != memmap.MLockNone { @@ -1081,20 +1074,8 @@ func (mm *MemoryManager) Decommit(addr usermem.Addr, length uint64) error { } } for pseg.Ok() && pseg.Start() < vsegAR.End { - pma := pseg.ValuePtr() - if pma.private && !mm.isPMACopyOnWriteLocked(vseg, pseg) { - psegAR := pseg.Range().Intersect(ar) - if vsegAR.IsSupersetOf(psegAR) && vma.mappable == nil { - if err := mf.Decommit(pseg.fileRangeOf(psegAR)); err == nil { - pseg = pseg.NextSegment() - continue - } - // If an error occurs, fall through to the general - // invalidation case below. - } - } pseg = mm.pmas.Isolate(pseg, vsegAR) - pma = pseg.ValuePtr() + pma := pseg.ValuePtr() if !didUnmapAS { // Unmap all of ar, not just pseg.Range(), to minimize host // syscalls. AddressSpace mappings must be removed before diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index b6ebe29d6..a8e6f172b 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -28,7 +28,6 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/hostfd", "//pkg/sentry/inet", diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 5b868216d..17f59ba1f 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -377,10 +377,8 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR: optlen = sizeofInt32 - case linux.IP_PKTINFO: - optlen = linux.SizeOfControlMessageIPPacketInfo } case linux.SOL_IPV6: switch name { diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index 9a2cac40b..f82c7c224 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/sentry/arch" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/hostfd" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -144,16 +143,6 @@ func (s *socketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs return int64(n), err } -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (s *socketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (s *socketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) -} - type socketProviderVFS2 struct { family int } diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index 70c561cce..2f913787b 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -15,7 +15,6 @@ package netfilter import ( - "bytes" "fmt" "gvisor.dev/gvisor/pkg/abi/linux" @@ -220,18 +219,6 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) } - n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) - if n == -1 { - n = len(iptip.OutputInterface) - } - ifname := string(iptip.OutputInterface[:n]) - - n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0) - if n == -1 { - n = len(iptip.OutputInterfaceMask) - } - ifnameMask := string(iptip.OutputInterfaceMask[:n]) - return stack.IPHeaderFilter{ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), // A Protocol value of 0 indicates all protocols match. @@ -242,8 +229,11 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { Src: tcpip.Address(iptip.Src[:]), SrcMask: tcpip.Address(iptip.SrcMask[:]), SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0, - OutputInterface: ifname, - OutputInterfaceMask: ifnameMask, + InputInterface: string(trimNullBytes(iptip.InputInterface[:])), + InputInterfaceMask: string(trimNullBytes(iptip.InputInterfaceMask[:])), + InputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_IN != 0, + OutputInterface: string(trimNullBytes(iptip.OutputInterface[:])), + OutputInterfaceMask: string(trimNullBytes(iptip.OutputInterfaceMask[:])), OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0, }, nil } @@ -254,12 +244,12 @@ func containsUnsupportedFields4(iptip linux.IPTIP) bool { // - Dst and DstMask // - Src and SrcMask // - The inverse destination IP check flag + // - InputInterface, InputInterfaceMask and its inverse. // - OutputInterface, OutputInterfaceMask and its inverse. - var emptyInterface = [linux.IFNAMSIZ]byte{} + const flagMask = 0 // Disable any supported inverse flags. - inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT) - return iptip.InputInterface != emptyInterface || - iptip.InputInterfaceMask != emptyInterface || - iptip.Flags != 0 || + const inverseMask = linux.IPT_INV_DSTIP | linux.IPT_INV_SRCIP | + linux.IPT_INV_VIA_IN | linux.IPT_INV_VIA_OUT + return iptip.Flags&^flagMask != 0 || iptip.InverseFlags&^inverseMask != 0 } diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 5dbb604f0..263d9d3b5 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -15,7 +15,6 @@ package netfilter import ( - "bytes" "fmt" "gvisor.dev/gvisor/pkg/abi/linux" @@ -223,18 +222,6 @@ func filterFromIP6TIP(iptip linux.IP6TIP) (stack.IPHeaderFilter, error) { return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) } - n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) - if n == -1 { - n = len(iptip.OutputInterface) - } - ifname := string(iptip.OutputInterface[:n]) - - n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0) - if n == -1 { - n = len(iptip.OutputInterfaceMask) - } - ifnameMask := string(iptip.OutputInterfaceMask[:n]) - return stack.IPHeaderFilter{ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), // In ip6tables a flag controls whether to check the protocol. @@ -245,8 +232,11 @@ func filterFromIP6TIP(iptip linux.IP6TIP) (stack.IPHeaderFilter, error) { Src: tcpip.Address(iptip.Src[:]), SrcMask: tcpip.Address(iptip.SrcMask[:]), SrcInvert: iptip.InverseFlags&linux.IP6T_INV_SRCIP != 0, - OutputInterface: ifname, - OutputInterfaceMask: ifnameMask, + InputInterface: string(trimNullBytes(iptip.InputInterface[:])), + InputInterfaceMask: string(trimNullBytes(iptip.InputInterfaceMask[:])), + InputInterfaceInvert: iptip.InverseFlags&linux.IP6T_INV_VIA_IN != 0, + OutputInterface: string(trimNullBytes(iptip.OutputInterface[:])), + OutputInterfaceMask: string(trimNullBytes(iptip.OutputInterfaceMask[:])), OutputInterfaceInvert: iptip.InverseFlags&linux.IP6T_INV_VIA_OUT != 0, }, nil } @@ -257,14 +247,13 @@ func containsUnsupportedFields6(iptip linux.IP6TIP) bool { // - Dst and DstMask // - Src and SrcMask // - The inverse destination IP check flag + // - InputInterface, InputInterfaceMask and its inverse. // - OutputInterface, OutputInterfaceMask and its inverse. - var emptyInterface = [linux.IFNAMSIZ]byte{} - flagMask := uint8(linux.IP6T_F_PROTO) + const flagMask = linux.IP6T_F_PROTO // Disable any supported inverse flags. - inverseMask := uint8(linux.IP6T_INV_DSTIP) | uint8(linux.IP6T_INV_SRCIP) | uint8(linux.IP6T_INV_VIA_OUT) - return iptip.InputInterface != emptyInterface || - iptip.InputInterfaceMask != emptyInterface || - iptip.Flags&^flagMask != 0 || + const inverseMask = linux.IP6T_INV_DSTIP | linux.IP6T_INV_SRCIP | + linux.IP6T_INV_VIA_IN | linux.IP6T_INV_VIA_OUT + return iptip.Flags&^flagMask != 0 || iptip.InverseFlags&^inverseMask != 0 || iptip.TOS != 0 } diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 26bd1abd4..7ae18b2a3 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -17,6 +17,7 @@ package netfilter import ( + "bytes" "errors" "fmt" @@ -393,3 +394,11 @@ func TargetRevision(t *kernel.Task, revPtr usermem.Addr, netProto tcpip.NetworkP rev.Revision = maxSupported return rev, nil } + +func trimNullBytes(b []byte) []byte { + n := bytes.IndexByte(b, 0) + if n == -1 { + n = len(b) + } + return b[:n] +} diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index 69d13745e..176fa6116 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -112,7 +112,7 @@ func (*OwnerMatcher) Name() string { } // Match implements Matcher.Match. -func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { +func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { // Support only for OUTPUT chain. // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also. if hook != stack.Output { diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 352c51390..2740697b3 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -96,7 +96,7 @@ func (*TCPMatcher) Name() string { } // Match implements Matcher.Match. -func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { +func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved // into the stack.Check codepath as matchers are added. switch pkt.NetworkProtocolNumber { diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index c88d8268d..466d5395d 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -93,7 +93,7 @@ func (*UDPMatcher) Name() string { } // Match implements Matcher.Match. -func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { +func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved // into the stack.Check codepath as matchers are added. switch pkt.NetworkProtocolNumber { diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 1f926aa91..9313e1167 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -22,7 +22,6 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go index 461d524e5..842036764 100644 --- a/pkg/sentry/socket/netlink/socket_vfs2.go +++ b/pkg/sentry/socket/netlink/socket_vfs2.go @@ -18,7 +18,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix" @@ -149,13 +148,3 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{}) return int64(n), err.ToError() } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 22abca120..915134b41 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -28,7 +28,6 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/inet", "//pkg/sentry/kernel", @@ -42,7 +41,6 @@ go_library( "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 22e128b96..7065a0e46 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -19,7 +19,7 @@ // be used to expose certain endpoints to the sentry while leaving others out, // for example, TCP endpoints and Unix-domain endpoints. // -// Lock ordering: netstack => mm: ioSequencePayload copies user memory inside +// Lock ordering: netstack => mm: ioSequenceReadWriter copies user memory inside // tcpip.Endpoint.Write(). Netstack is allowed to (and does) hold locks during // this operation. package netstack @@ -55,7 +55,6 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -194,7 +193,6 @@ var Metrics = tcpip.Stats{ RequestsReceivedUnknownTargetAddress: mustCreateMetric("/netstack/arp/requests_received_unknown_addr", "Number of ARP requests received with an unknown target address."), OutgoingRequestInterfaceHasNoLocalAddressErrors: mustCreateMetric("/netstack/arp/outgoing_requests_iface_has_no_addr", "Number of failed attempts to send an ARP request with an interface that has no network address."), OutgoingRequestBadLocalAddressErrors: mustCreateMetric("/netstack/arp/outgoing_requests_invalid_local_addr", "Number of failed attempts to send an ARP request with a provided local address that is invalid."), - OutgoingRequestNetworkUnreachableErrors: mustCreateMetric("/netstack/arp/outgoing_requests_network_unreachable", "Number of failed attempts to send an ARP request with a network unreachable error."), OutgoingRequestsDropped: mustCreateMetric("/netstack/arp/outgoing_requests_dropped", "Number of ARP requests which failed to write to a link-layer endpoint."), OutgoingRequestsSent: mustCreateMetric("/netstack/arp/outgoing_requests_sent", "Number of ARP requests sent."), RepliesReceived: mustCreateMetric("/netstack/arp/replies_received", "Number of ARP replies received."), @@ -440,45 +438,10 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write return int64(res.Count), nil } -// ioSequencePayload implements tcpip.Payload. -// -// t copies user memory bytes on demand based on the requested size. -type ioSequencePayload struct { - ctx context.Context - src usermem.IOSequence -} - -// FullPayload implements tcpip.Payloader.FullPayload -func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) { - return i.Payload(int(i.src.NumBytes())) -} - -// Payload implements tcpip.Payloader.Payload. -func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) { - if max := int(i.src.NumBytes()); size > max { - size = max - } - v := buffer.NewView(size) - if _, err := i.src.CopyIn(i.ctx, v); err != nil { - // EOF can be returned only if src is a file and this means it - // is in a splice syscall and the error has to be ignored. - if err == io.EOF { - return v, nil - } - return nil, tcpip.ErrBadAddress - } - return v, nil -} - -// DropFirst drops the first n bytes from underlying src. -func (i *ioSequencePayload) DropFirst(n int) { - i.src = i.src.DropFirst(int(n)) -} - // Write implements fs.FileOperations.Write. func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { - f := &ioSequencePayload{ctx: ctx, src: src} - n, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) + r := src.Reader(ctx) + n, err := s.Endpoint.Write(r, tcpip.WriteOptions{}) if err == tcpip.ErrWouldBlock { return 0, syserror.ErrWouldBlock } @@ -486,69 +449,40 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO return 0, syserr.TranslateNetstackError(err).ToError() } - if int64(n) < src.NumBytes() { - return int64(n), syserror.ErrWouldBlock + if n < src.NumBytes() { + return n, syserror.ErrWouldBlock } - return int64(n), nil + return n, nil } -// readerPayload implements tcpip.Payloader. -// -// It allocates a view and reads from a reader on-demand, based on available -// capacity in the endpoint. -type readerPayload struct { - ctx context.Context - r io.Reader - count int64 - err error -} +var _ tcpip.Payloader = (*limitedPayloader)(nil) -// FullPayload implements tcpip.Payloader.FullPayload. -func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) { - return r.Payload(int(r.count)) +type limitedPayloader struct { + io.LimitedReader } -// Payload implements tcpip.Payloader.Payload. -func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) { - if size > int(r.count) { - size = int(r.count) - } - v := buffer.NewView(size) - n, err := r.r.Read(v) - if n > 0 { - // We ignore the error here. It may re-occur on subsequent - // reads, but for now we can enqueue some amount of data. - r.count -= int64(n) - return v[:n], nil - } - if err == syserror.ErrWouldBlock { - return nil, tcpip.ErrWouldBlock - } else if err != nil { - r.err = err // Save for propation. - return nil, tcpip.ErrBadAddress - } - - // There is no data and no error. Return an error, which will propagate - // r.err, which will be nil. This is the desired result: (0, nil). - return nil, tcpip.ErrBadAddress +func (l limitedPayloader) Len() int { + return int(l.N) } // ReadFrom implements fs.FileOperations.ReadFrom. func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { - f := &readerPayload{ctx: ctx, r: r, count: count} - n, err := s.Endpoint.Write(f, tcpip.WriteOptions{ + f := limitedPayloader{ + LimitedReader: io.LimitedReader{ + R: r, + N: count, + }, + } + n, err := s.Endpoint.Write(&f, tcpip.WriteOptions{ // Reads may be destructive but should be very fast, // so we can't release the lock while copying data. Atomic: true, }) - if err == tcpip.ErrWouldBlock { - return n, syserror.ErrWouldBlock - } else if err != nil { - return int64(n), f.err // Propagate error. + if err == tcpip.ErrBadBuffer { + err = nil } - - return int64(n), nil + return n, syserr.TranslateNetstackError(err).ToError() } // Readiness returns a mask of ready events for socket s. @@ -912,7 +846,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + size, err := ep.SocketOptions().GetSendBufferSize() if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1681,8 +1615,16 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return syserr.ErrInvalidArgument } + family, _, _ := s.Type() + // TODO(gvisor.dev/issue/5132): We currently do not support + // setting this option for unix sockets. + if family == linux.AF_UNIX { + return nil + } + v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v))) + ep.SocketOptions().SetSendBufferSize(int64(v), true) + return nil case linux.SO_RCVBUF: if len(optVal) < sizeOfInt32 { @@ -1814,10 +1756,6 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam var v linux.Linger binary.Unmarshal(optVal[:linux.SizeOfLinger], usermem.ByteOrder, &v) - if v != (linux.Linger{}) { - socket.SetSockOptEmitUnimplementedEvent(t, name) - } - ep.SocketOptions().SetLinger(tcpip.LingerOption{ Enabled: v.OnOff != 0, Timeout: time.Second * time.Duration(v.Linger), @@ -2840,45 +2778,46 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b EndOfRecord: flags&linux.MSG_EOR != 0, } - v := &ioSequencePayload{t, src} - n, err := s.Endpoint.Write(v, opts) - dontWait := flags&linux.MSG_DONTWAIT != 0 - if err == nil && (n >= v.src.NumBytes() || dontWait) { - // Complete write. - return int(n), nil - } - if err != nil && (err != tcpip.ErrWouldBlock || dontWait) { - return int(n), syserr.TranslateNetstackError(err) - } - - // We'll have to block. Register for notification and keep trying to - // send all the data. - e, ch := waiter.NewChannelEntry(nil) - s.EventRegister(&e, waiter.EventOut) - defer s.EventUnregister(&e) - - v.DropFirst(int(n)) - total := n + r := src.Reader(t) + var ( + total int64 + entry waiter.Entry + ch <-chan struct{} + ) for { - n, err = s.Endpoint.Write(v, opts) - v.DropFirst(int(n)) + n, err := s.Endpoint.Write(r, opts) total += n - - if err != nil && err != tcpip.ErrWouldBlock && total == 0 { - return 0, syserr.TranslateNetstackError(err) + if flags&linux.MSG_DONTWAIT != 0 { + return int(total), syserr.TranslateNetstackError(err) } - - if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock { - return int(total), nil - } - - if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { - return int(total), syserr.ErrTryAgain + switch err { + case nil: + if total == src.NumBytes() { + break + } + fallthrough + case tcpip.ErrWouldBlock: + if ch == nil { + // We'll have to block. Register for notification and keep trying to + // send all the data. + entry, ch = waiter.NewChannelEntry(nil) + s.EventRegister(&entry, waiter.EventOut) + defer s.EventUnregister(&entry) + } else { + // Don't wait immediately after registration in case more data + // became available between when we last checked and when we setup + // the notification. + if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { + if err == syserror.ETIMEDOUT { + return int(total), syserr.ErrTryAgain + } + // handleIOError will consume errors from t.Block if needed. + return int(total), syserr.FromError(err) + } } - // handleIOError will consume errors from t.Block if needed. - return int(total), syserr.FromError(err) + continue } + return int(total), syserr.TranslateNetstackError(err) } } diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index 6f70b02fc..3bbdf552e 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" @@ -129,8 +128,8 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs return 0, syserror.EOPNOTSUPP } - f := &ioSequencePayload{ctx: ctx, src: src} - n, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) + r := src.Reader(ctx) + n, err := s.Endpoint.Write(r, tcpip.WriteOptions{}) if err == tcpip.ErrWouldBlock { return 0, syserror.ErrWouldBlock } @@ -138,11 +137,11 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs return 0, syserr.TranslateNetstackError(err).ToError() } - if int64(n) < src.NumBytes() { - return int64(n), syserror.ErrWouldBlock + if n < src.NumBytes() { + return n, syserror.ErrWouldBlock } - return int64(n), nil + return n, nil } // Accept implements the linux syscall accept(2) for sockets backed by @@ -262,13 +261,3 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by return SetSockOpt(t, s, s.Endpoint, level, name, optVal) } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 9f7aca305..b011082dc 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -128,7 +128,7 @@ func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProv idGenerator: uid, stype: stype, } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, nil, nil) return ep } @@ -173,7 +173,7 @@ func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider idGenerator: uid, stype: stype, } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, nil, nil) return ep } @@ -296,7 +296,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn idGenerator: e.idGenerator, stype: e.stype, } - ne.ops.InitHandler(ne) + ne.ops.InitHandler(ne, nil, nil) readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit} readQueue.InitRefs() diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 0813ad87d..20fa8b874 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -44,7 +44,7 @@ func NewConnectionless(ctx context.Context) Endpoint { q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit} q.InitRefs() ep.receiver = &queueReceiver{readQueue: &q} - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, nil, nil) return ep } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 099a56281..0e3889c6d 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -842,7 +842,6 @@ func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { - case tcpip.SendBufferSizeOption: case tcpip.ReceiveBufferSizeOption: default: log.Warningf("Unsupported socket option: %d", opt) @@ -850,6 +849,27 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +// IsUnixSocket implements tcpip.SocketOptionsHandler.IsUnixSocket. +func (e *baseEndpoint) IsUnixSocket() bool { + return true +} + +// GetSendBufferSize implements tcpip.SocketOptionsHandler.GetSendBufferSize. +func (e *baseEndpoint) GetSendBufferSize() (int64, *tcpip.Error) { + e.Lock() + defer e.Unlock() + + if !e.Connected() { + return -1, tcpip.ErrNotConnected + } + + v := e.connected.SendMaxQueueSize() + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported + } + return v, nil +} + func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -879,19 +899,6 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } return int(v), nil - case tcpip.SendBufferSizeOption: - e.Lock() - if !e.Connected() { - e.Unlock() - return -1, tcpip.ErrNotConnected - } - v := e.connected.SendMaxQueueSize() - e.Unlock() - if v < 0 { - return -1, tcpip.ErrQueueSizeNotSupported - } - return int(v), nil - case tcpip.ReceiveBufferSizeOption: e.Lock() if e.receiver == nil { diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 6c4ec55b2..32e5d2304 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -496,6 +496,9 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b return int(n), syserr.FromError(err) } + // Only send SCM Rights once (see net/unix/af_unix.c:unix_stream_sendmsg). + w.Control.Rights = nil + // We'll have to block. Register for notification and keep trying to // send all the data. e, ch := waiter.NewChannelEntry(nil) diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 27f705bb2..a7d4d7f1f 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/arch" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" @@ -331,16 +330,6 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by return netstack.SetSockOpt(t, s, s.ep, level, name, optVal) } -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) -} - // providerVFS2 is a unix domain socket provider for VFS2. type providerVFS2 struct{} diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index a72df62f6..62d1e8f8b 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -503,8 +503,8 @@ var ARM64 = &kernel.SyscallTable{ 72: syscalls.Supported("pselect", Pselect), 73: syscalls.Supported("ppoll", Ppoll), 74: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), - 75: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) - 76: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 75: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 76: syscalls.Supported("splice", Splice), 77: syscalls.Supported("tee", Tee), 78: syscalls.Supported("readlinkat", Readlinkat), 79: syscalls.Supported("fstatat", Fstatat), diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index c33571f43..a6253626e 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -1014,12 +1014,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } if cmd == linux.F_SETLK { // Non-blocking lock, provide a nil lock.Blocker. - if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.ReadLock, rng, nil) { + if !file.Dirent.Inode.LockCtx.Posix.LockRegionVFS1(t.FDTable(), lock.ReadLock, rng, nil) { return 0, nil, syserror.EAGAIN } } else { // Blocking lock, pass in the task to satisfy the lock.Blocker interface. - if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.ReadLock, rng, t) { + if !file.Dirent.Inode.LockCtx.Posix.LockRegionVFS1(t.FDTable(), lock.ReadLock, rng, t) { return 0, nil, syserror.EINTR } } @@ -1030,12 +1030,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } if cmd == linux.F_SETLK { // Non-blocking lock, provide a nil lock.Blocker. - if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.WriteLock, rng, nil) { + if !file.Dirent.Inode.LockCtx.Posix.LockRegionVFS1(t.FDTable(), lock.WriteLock, rng, nil) { return 0, nil, syserror.EAGAIN } } else { // Blocking lock, pass in the task to satisfy the lock.Blocker interface. - if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.WriteLock, rng, t) { + if !file.Dirent.Inode.LockCtx.Posix.LockRegionVFS1(t.FDTable(), lock.WriteLock, rng, t) { return 0, nil, syserror.EINTR } } @@ -2167,24 +2167,24 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.LOCK_EX: if nonblocking { // Since we're nonblocking we pass a nil lock.Blocker implementation. - if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.WriteLock, rng, nil) { + if !file.Dirent.Inode.LockCtx.BSD.LockRegionVFS1(file, lock.WriteLock, rng, nil) { return 0, nil, syserror.EWOULDBLOCK } } else { // Because we're blocking we will pass the task to satisfy the lock.Blocker interface. - if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.WriteLock, rng, t) { + if !file.Dirent.Inode.LockCtx.BSD.LockRegionVFS1(file, lock.WriteLock, rng, t) { return 0, nil, syserror.EINTR } } case linux.LOCK_SH: if nonblocking { // Since we're nonblocking we pass a nil lock.Blocker implementation. - if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.ReadLock, rng, nil) { + if !file.Dirent.Inode.LockCtx.BSD.LockRegionVFS1(file, lock.ReadLock, rng, nil) { return 0, nil, syserror.EWOULDBLOCK } } else { // Because we're blocking we will pass the task to satisfy the lock.Blocker interface. - if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.ReadLock, rng, t) { + if !file.Dirent.Inode.LockCtx.BSD.LockRegionVFS1(file, lock.ReadLock, rng, t) { return 0, nil, syserror.EINTR } } diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index 7dd9ef857..e39f074f2 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -205,8 +205,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } err := tmpfs.AddSeals(file, args[2].Uint()) return 0, nil, err - case linux.F_SETLK, linux.F_SETLKW: - return 0, nil, posixLock(t, args, file, cmd) + case linux.F_SETLK: + return 0, nil, posixLock(t, args, file, false /* blocking */) + case linux.F_SETLKW: + return 0, nil, posixLock(t, args, file, true /* blocking */) + case linux.F_GETLK: + return 0, nil, posixTestLock(t, args, file) case linux.F_GETSIG: a := file.AsyncHandler() if a == nil { @@ -292,7 +296,49 @@ func setAsyncOwner(t *kernel.Task, fd int, file *vfs.FileDescription, ownerType, } } -func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescription, cmd int32) error { +func posixTestLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescription) error { + // Copy in the lock request. + flockAddr := args[2].Pointer() + var flock linux.Flock + if _, err := flock.CopyIn(t, flockAddr); err != nil { + return err + } + var typ lock.LockType + switch flock.Type { + case linux.F_RDLCK: + typ = lock.ReadLock + case linux.F_WRLCK: + typ = lock.WriteLock + default: + return syserror.EINVAL + } + r, err := file.ComputeLockRange(t, uint64(flock.Start), uint64(flock.Len), flock.Whence) + if err != nil { + return err + } + + newFlock, err := file.TestPOSIX(t, t.FDTable(), typ, r) + if err != nil { + return err + } + newFlock.PID = translatePID(t.PIDNamespace().Root(), t.PIDNamespace(), newFlock.PID) + if _, err = newFlock.CopyOut(t, flockAddr); err != nil { + return err + } + return nil +} + +// translatePID translates a pid from one namespace to another. Note that this +// may race with task termination/creation, in which case the original task +// corresponding to pid may no longer exist. This is used to implement the +// F_GETLK fcntl, which has the same potential race in Linux as well (i.e., +// there is no synchronization between retrieving the lock PID and translating +// it). See fs/locks.c:posix_lock_to_flock. +func translatePID(old, new *kernel.PIDNamespace, pid int32) int32 { + return int32(new.IDOfTask(old.TaskWithID(kernel.ThreadID(pid)))) +} + +func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescription, blocking bool) error { // Copy in the lock request. flockAddr := args[2].Pointer() var flock linux.Flock @@ -301,25 +347,30 @@ func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescrip } var blocker lock.Blocker - if cmd == linux.F_SETLKW { + if blocking { blocker = t } + r, err := file.ComputeLockRange(t, uint64(flock.Start), uint64(flock.Len), flock.Whence) + if err != nil { + return err + } + switch flock.Type { case linux.F_RDLCK: if !file.IsReadable() { return syserror.EBADF } - return file.LockPOSIX(t, t.FDTable(), lock.ReadLock, uint64(flock.Start), uint64(flock.Len), flock.Whence, blocker) + return file.LockPOSIX(t, t.FDTable(), int32(t.TGIDInRoot()), lock.ReadLock, r, blocker) case linux.F_WRLCK: if !file.IsWritable() { return syserror.EBADF } - return file.LockPOSIX(t, t.FDTable(), lock.WriteLock, uint64(flock.Start), uint64(flock.Len), flock.Whence, blocker) + return file.LockPOSIX(t, t.FDTable(), int32(t.TGIDInRoot()), lock.WriteLock, r, blocker) case linux.F_UNLCK: - return file.UnlockPOSIX(t, t.FDTable(), uint64(flock.Start), uint64(flock.Len), flock.Whence) + return file.UnlockPOSIX(t, t.FDTable(), r) default: return syserror.EINVAL diff --git a/pkg/sentry/syscalls/linux/vfs2/lock.go b/pkg/sentry/syscalls/linux/vfs2/lock.go index b910b5a74..d1452a04d 100644 --- a/pkg/sentry/syscalls/linux/vfs2/lock.go +++ b/pkg/sentry/syscalls/linux/vfs2/lock.go @@ -44,11 +44,11 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall switch operation { case linux.LOCK_EX: - if err := file.LockBSD(t, lock.WriteLock, blocker); err != nil { + if err := file.LockBSD(t, int32(t.TGIDInRoot()), lock.WriteLock, blocker); err != nil { return 0, nil, err } case linux.LOCK_SH: - if err := file.LockBSD(t, lock.ReadLock, blocker); err != nil { + if err := file.LockBSD(t, int32(t.TGIDInRoot()), lock.ReadLock, blocker); err != nil { return 0, nil, err } case linux.LOCK_UN: diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go index b77b29dcc..c7417840f 100644 --- a/pkg/sentry/syscalls/linux/vfs2/read_write.go +++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go @@ -93,17 +93,11 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { n, err := file.Read(t, dst, opts) if err != syserror.ErrWouldBlock { - if n > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - } return n, err } allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { - if n > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - } return n, err } @@ -134,9 +128,6 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt } file.EventUnregister(&w) - if total > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - } return total, err } @@ -257,17 +248,11 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { n, err := file.PRead(t, dst, offset, opts) if err != syserror.ErrWouldBlock { - if n > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - } return n, err } allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { - if n > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - } return n, err } @@ -297,10 +282,6 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of } } file.EventUnregister(&w) - - if total > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - } return total, err } @@ -363,17 +344,11 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { n, err := file.Write(t, src, opts) if err != syserror.ErrWouldBlock { - if n > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - } return n, err } allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { - if n > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - } return n, err } @@ -403,10 +378,6 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op } } file.EventUnregister(&w) - - if total > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - } return total, err } @@ -527,17 +498,11 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { n, err := file.PWrite(t, src, offset, opts) if err != syserror.ErrWouldBlock { - if n > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - } return n, err } allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { - if n > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - } return n, err } @@ -567,10 +532,6 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o } } file.EventUnregister(&w) - - if total > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - } return total, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go index 1ee37e5a8..903169dc2 100644 --- a/pkg/sentry/syscalls/linux/vfs2/setstat.go +++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go @@ -220,7 +220,6 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys length := args[3].Int64() file := t.GetFileVFS2(fd) - if file == nil { return 0, nil, syserror.EBADF } @@ -229,23 +228,18 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if !file.IsWritable() { return 0, nil, syserror.EBADF } - if mode != 0 { return 0, nil, syserror.ENOTSUP } - if offset < 0 || length <= 0 { return 0, nil, syserror.EINVAL } size := offset + length - if size < 0 { return 0, nil, syserror.EFBIG } - limit := limits.FromContext(t).Get(limits.FileSize).Cur - if uint64(size) >= limit { t.SendSignal(&arch.SignalInfo{ Signo: int32(linux.SIGXFSZ), @@ -254,12 +248,7 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, syserror.EFBIG } - if err := file.Allocate(t, mode, uint64(offset), uint64(length)); err != nil { - return 0, nil, err - } - - file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - return 0, nil, nil + return 0, nil, file.Allocate(t, mode, uint64(offset), uint64(length)) } // Utime implements Linux syscall utime(2). diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index 8bb763a47..19e175203 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -170,13 +170,6 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } } - if n != 0 { - // On Linux, inotify behavior is not very consistent with splice(2). We try - // our best to emulate Linux for very basic calls to splice, where for some - // reason, events are generated for output files, but not input files. - outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - } - // We can only pass a single file to handleIOError, so pick inFile arbitrarily. // This is used only for debugging purposes. return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "splice", outFile) @@ -256,8 +249,6 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo } if n != 0 { - outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - // If a partial write is completed, the error is dropped. Log it here. if err != nil && err != io.EOF && err != syserror.ErrWouldBlock { log.Debugf("tee completed a partial write with error: %v", err) @@ -449,9 +440,6 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } if total != 0 { - inFile.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - if err != nil && err != io.EOF && err != syserror.ErrWouldBlock { // If a partial write is completed, the error is dropped. Log it here. log.Debugf("sendfile completed a partial write with error: %v", err) diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 5321ac80a..f612a71b2 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -161,6 +161,13 @@ func (fd *FileDescription) Init(impl FileDescriptionImpl, flags uint32, mnt *Mou // DecRef decrements fd's reference count. func (fd *FileDescription) DecRef(ctx context.Context) { fd.FileDescriptionRefs.DecRef(func() { + // Generate inotify events. + ev := uint32(linux.IN_CLOSE_NOWRITE) + if fd.IsWritable() { + ev = linux.IN_CLOSE_WRITE + } + fd.Dentry().InotifyWithParent(ctx, ev, 0, PathEvent) + // Unregister fd from all epoll instances. fd.epollMu.Lock() epolls := fd.epolls @@ -448,16 +455,19 @@ type FileDescriptionImpl interface { RemoveXattr(ctx context.Context, name string) error // LockBSD tries to acquire a BSD-style advisory file lock. - LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error + LockBSD(ctx context.Context, uid lock.UniqueID, ownerPID int32, t lock.LockType, block lock.Blocker) error // UnlockBSD releases a BSD-style advisory file lock. UnlockBSD(ctx context.Context, uid lock.UniqueID) error // LockPOSIX tries to acquire a POSIX-style advisory file lock. - LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, start, length uint64, whence int16, block lock.Blocker) error + LockPOSIX(ctx context.Context, uid lock.UniqueID, ownerPID int32, t lock.LockType, r lock.LockRange, block lock.Blocker) error // UnlockPOSIX releases a POSIX-style advisory file lock. - UnlockPOSIX(ctx context.Context, uid lock.UniqueID, start, length uint64, whence int16) error + UnlockPOSIX(ctx context.Context, uid lock.UniqueID, ComputeLockRange lock.LockRange) error + + // TestPOSIX returns information about whether the specified lock can be held, in the style of the F_GETLK fcntl. + TestPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, r lock.LockRange) (linux.Flock, error) } // Dirent holds the information contained in struct linux_dirent64. @@ -556,7 +566,11 @@ func (fd *FileDescription) Allocate(ctx context.Context, mode, offset, length ui if !fd.IsWritable() { return syserror.EBADF } - return fd.impl.Allocate(ctx, mode, offset, length) + if err := fd.impl.Allocate(ctx, mode, offset, length); err != nil { + return err + } + fd.Dentry().InotifyWithParent(ctx, linux.IN_MODIFY, 0, PathEvent) + return nil } // Readiness implements waiter.Waitable.Readiness. @@ -592,6 +606,9 @@ func (fd *FileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of } start := fsmetric.StartReadWait() n, err := fd.impl.PRead(ctx, dst, offset, opts) + if n > 0 { + fd.Dentry().InotifyWithParent(ctx, linux.IN_ACCESS, 0, PathEvent) + } fsmetric.Reads.Increment() fsmetric.FinishReadWait(fsmetric.ReadWait, start) return n, err @@ -604,6 +621,9 @@ func (fd *FileDescription) Read(ctx context.Context, dst usermem.IOSequence, opt } start := fsmetric.StartReadWait() n, err := fd.impl.Read(ctx, dst, opts) + if n > 0 { + fd.Dentry().InotifyWithParent(ctx, linux.IN_ACCESS, 0, PathEvent) + } fsmetric.Reads.Increment() fsmetric.FinishReadWait(fsmetric.ReadWait, start) return n, err @@ -619,7 +639,11 @@ func (fd *FileDescription) PWrite(ctx context.Context, src usermem.IOSequence, o if !fd.writable { return 0, syserror.EBADF } - return fd.impl.PWrite(ctx, src, offset, opts) + n, err := fd.impl.PWrite(ctx, src, offset, opts) + if n > 0 { + fd.Dentry().InotifyWithParent(ctx, linux.IN_MODIFY, 0, PathEvent) + } + return n, err } // Write is similar to PWrite, but does not specify an offset. @@ -627,7 +651,11 @@ func (fd *FileDescription) Write(ctx context.Context, src usermem.IOSequence, op if !fd.writable { return 0, syserror.EBADF } - return fd.impl.Write(ctx, src, opts) + n, err := fd.impl.Write(ctx, src, opts) + if n > 0 { + fd.Dentry().InotifyWithParent(ctx, linux.IN_MODIFY, 0, PathEvent) + } + return n, err } // IterDirents invokes cb on each entry in the directory represented by fd. If @@ -791,9 +819,9 @@ func (fd *FileDescription) Msync(ctx context.Context, mr memmap.MappableRange) e } // LockBSD tries to acquire a BSD-style advisory file lock. -func (fd *FileDescription) LockBSD(ctx context.Context, lockType lock.LockType, blocker lock.Blocker) error { +func (fd *FileDescription) LockBSD(ctx context.Context, ownerPID int32, lockType lock.LockType, blocker lock.Blocker) error { atomic.StoreUint32(&fd.usedLockBSD, 1) - return fd.impl.LockBSD(ctx, fd, lockType, blocker) + return fd.impl.LockBSD(ctx, fd, ownerPID, lockType, blocker) } // UnlockBSD releases a BSD-style advisory file lock. @@ -802,13 +830,45 @@ func (fd *FileDescription) UnlockBSD(ctx context.Context) error { } // LockPOSIX locks a POSIX-style file range lock. -func (fd *FileDescription) LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, start, end uint64, whence int16, block lock.Blocker) error { - return fd.impl.LockPOSIX(ctx, uid, t, start, end, whence, block) +func (fd *FileDescription) LockPOSIX(ctx context.Context, uid lock.UniqueID, ownerPID int32, t lock.LockType, r lock.LockRange, block lock.Blocker) error { + return fd.impl.LockPOSIX(ctx, uid, ownerPID, t, r, block) } // UnlockPOSIX unlocks a POSIX-style file range lock. -func (fd *FileDescription) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, start, end uint64, whence int16) error { - return fd.impl.UnlockPOSIX(ctx, uid, start, end, whence) +func (fd *FileDescription) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, r lock.LockRange) error { + return fd.impl.UnlockPOSIX(ctx, uid, r) +} + +// TestPOSIX returns information about whether the specified lock can be held. +func (fd *FileDescription) TestPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, r lock.LockRange) (linux.Flock, error) { + return fd.impl.TestPOSIX(ctx, uid, t, r) +} + +// ComputeLockRange computes the range of a file lock based on the given values. +func (fd *FileDescription) ComputeLockRange(ctx context.Context, start uint64, length uint64, whence int16) (lock.LockRange, error) { + var off int64 + switch whence { + case linux.SEEK_SET: + off = 0 + case linux.SEEK_CUR: + // Note that Linux does not hold any mutexes while retrieving the file + // offset, see fs/locks.c:flock_to_posix_lock and fs/locks.c:fcntl_setlk. + curOff, err := fd.Seek(ctx, 0, linux.SEEK_CUR) + if err != nil { + return lock.LockRange{}, err + } + off = curOff + case linux.SEEK_END: + stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_SIZE}) + if err != nil { + return lock.LockRange{}, err + } + off = int64(stat.Size) + default: + return lock.LockRange{}, syserror.EINVAL + } + + return lock.ComputeRange(int64(start), int64(length), off) } // A FileAsync sends signals to its owner when w is ready for IO. This is only diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 48ca9de44..eb7d2fd3b 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -419,8 +419,8 @@ func (fd *LockFD) Locks() *FileLocks { } // LockBSD implements vfs.FileDescriptionImpl.LockBSD. -func (fd *LockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error { - return fd.locks.LockBSD(uid, t, block) +func (fd *LockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { + return fd.locks.LockBSD(ctx, uid, ownerPID, t, block) } // UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD. @@ -429,6 +429,21 @@ func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { return nil } +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *LockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error { + return fd.locks.LockPOSIX(ctx, uid, ownerPID, t, r, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *LockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error { + return fd.locks.UnlockPOSIX(ctx, uid, r) +} + +// TestPOSIX implements vfs.FileDescriptionImpl.TestPOSIX. +func (fd *LockFD) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) { + return fd.locks.TestPOSIX(ctx, uid, t, r) +} + // NoLockFD implements Lock*/Unlock* portion of FileDescriptionImpl interface // returning ENOLCK. // @@ -436,7 +451,7 @@ func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { type NoLockFD struct{} // LockBSD implements vfs.FileDescriptionImpl.LockBSD. -func (NoLockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error { +func (NoLockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { return syserror.ENOLCK } @@ -446,11 +461,16 @@ func (NoLockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { } // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (NoLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { +func (NoLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error { return syserror.ENOLCK } // UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (NoLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { +func (NoLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error { return syserror.ENOLCK } + +// TestPOSIX implements vfs.FileDescriptionImpl.TestPOSIX. +func (NoLockFD) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) { + return linux.Flock{}, syserror.ENOLCK +} diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go index 1ff202f2a..cbe4d8c2d 100644 --- a/pkg/sentry/vfs/lock.go +++ b/pkg/sentry/vfs/lock.go @@ -39,8 +39,8 @@ type FileLocks struct { } // LockBSD tries to acquire a BSD-style lock on the entire file. -func (fl *FileLocks) LockBSD(uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error { - if fl.bsd.LockRegion(uid, t, fslock.LockRange{0, fslock.LockEOF}, block) { +func (fl *FileLocks) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerID int32, t fslock.LockType, block fslock.Blocker) error { + if fl.bsd.LockRegion(uid, ownerID, t, fslock.LockRange{0, fslock.LockEOF}, block) { return nil } @@ -61,12 +61,8 @@ func (fl *FileLocks) UnlockBSD(uid fslock.UniqueID) { } // LockPOSIX tries to acquire a POSIX-style lock on a file region. -func (fl *FileLocks) LockPOSIX(ctx context.Context, fd *FileDescription, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - rng, err := computeRange(ctx, fd, start, length, whence) - if err != nil { - return err - } - if fl.posix.LockRegion(uid, t, rng, block) { +func (fl *FileLocks) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error { + if fl.posix.LockRegion(uid, ownerPID, t, r, block) { return nil } @@ -82,37 +78,12 @@ func (fl *FileLocks) LockPOSIX(ctx context.Context, fd *FileDescription, uid fsl // // This operation is always successful, even if there did not exist a lock on // the requested region held by uid in the first place. -func (fl *FileLocks) UnlockPOSIX(ctx context.Context, fd *FileDescription, uid fslock.UniqueID, start, length uint64, whence int16) error { - rng, err := computeRange(ctx, fd, start, length, whence) - if err != nil { - return err - } - fl.posix.UnlockRegion(uid, rng) +func (fl *FileLocks) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error { + fl.posix.UnlockRegion(uid, r) return nil } -func computeRange(ctx context.Context, fd *FileDescription, start uint64, length uint64, whence int16) (fslock.LockRange, error) { - var off int64 - switch whence { - case linux.SEEK_SET: - off = 0 - case linux.SEEK_CUR: - // Note that Linux does not hold any mutexes while retrieving the file - // offset, see fs/locks.c:flock_to_posix_lock and fs/locks.c:fcntl_setlk. - curOff, err := fd.Seek(ctx, 0, linux.SEEK_CUR) - if err != nil { - return fslock.LockRange{}, err - } - off = curOff - case linux.SEEK_END: - stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_SIZE}) - if err != nil { - return fslock.LockRange{}, err - } - off = int64(stat.Size) - default: - return fslock.LockRange{}, syserror.EINVAL - } - - return fslock.ComputeRange(int64(start), int64(length), off) +// TestPOSIX returns information about whether the specified lock can be held, in the style of the F_GETLK fcntl. +func (fl *FileLocks) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) { + return fl.posix.TestRegion(ctx, uid, t, r), nil } diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index fdeec12d3..7c7495c30 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -16,6 +16,7 @@ package gonet import ( + "bytes" "context" "errors" "io" @@ -354,8 +355,6 @@ func (c *TCPConn) Write(b []byte) (int, error) { default: } - v := buffer.NewViewFromBytes(b) - // We must handle two soft failure conditions simultaneously: // 1. Write may write nothing and return tcpip.ErrWouldBlock. // If this happens, we need to register for notifications if we have @@ -368,22 +367,23 @@ func (c *TCPConn) Write(b []byte) (int, error) { // There is no guarantee that all of the condition #1s will occur before // all of the condition #2s or visa-versa. var ( - err *tcpip.Error - nbytes int - reg bool - notifyCh chan struct{} + r bytes.Reader + nbytes int + entry waiter.Entry + ch <-chan struct{} ) - for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) { - if err == tcpip.ErrWouldBlock { - if !reg { - // Only register once. - reg = true - - // Create wait queue entry that notifies a channel. - var waitEntry waiter.Entry - waitEntry, notifyCh = waiter.NewChannelEntry(nil) - c.wq.EventRegister(&waitEntry, waiter.EventOut) - defer c.wq.EventUnregister(&waitEntry) + for nbytes != len(b) { + r.Reset(b[nbytes:]) + n, err := c.ep.Write(&r, tcpip.WriteOptions{}) + nbytes += int(n) + switch err { + case nil: + case tcpip.ErrWouldBlock: + if ch == nil { + entry, ch = waiter.NewChannelEntry(nil) + + c.wq.EventRegister(&entry, waiter.EventOut) + defer c.wq.EventUnregister(&entry) } else { // Don't wait immediately after registration in case more data // became available between when we last checked and when we setup @@ -391,22 +391,15 @@ func (c *TCPConn) Write(b []byte) (int, error) { select { case <-deadline: return nbytes, c.newOpError("write", &timeoutError{}) - case <-notifyCh: + case <-ch: + continue } } + default: + return nbytes, c.newOpError("write", errors.New(err.String())) } - - var n int64 - n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - nbytes += int(n) - v.TrimFront(int(n)) - } - - if err == nil { - return nbytes, nil } - - return nbytes, c.newOpError("write", errors.New(err.String())) + return nbytes, nil } // Close implements net.Conn.Close. @@ -644,16 +637,18 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { } // If we're being called by Write, there is no addr - wopts := tcpip.WriteOptions{} + writeOptions := tcpip.WriteOptions{} if addr != nil { ua := addr.(*net.UDPAddr) - wopts.To = &tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)} + writeOptions.To = &tcpip.FullAddress{ + Addr: tcpip.Address(ua.IP), + Port: uint16(ua.Port), + } } - v := buffer.NewView(len(b)) - copy(v, b) - - n, err := c.ep.Write(tcpip.SlicePayload(v), wopts) + var r bytes.Reader + r.Reset(b) + n, err := c.ep.Write(&r, writeOptions) if err == tcpip.ErrWouldBlock { // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) @@ -666,7 +661,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { case <-notifyCh: } - n, err = c.ep.Write(tcpip.SlicePayload(v), wopts) + n, err = c.ep.Write(&r, writeOptions) if err != tcpip.ErrWouldBlock { break } diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 0ac2000ca..07b4393a4 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -234,7 +234,7 @@ func IPv4RouterAlert() NetworkChecker { for { opt, done, err := iterator.Next() if err != nil { - t.Fatalf("error acquiring next IPv4 option %s", err) + t.Fatalf("error acquiring next IPv4 option at offset %d", err.Pointer) } if done { break diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index e6103f4bc..48ca60319 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package header import ( "encoding/binary" - "errors" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -481,15 +480,13 @@ const ( IPv4OptionLengthOffset = 1 ) -// Potential errors when parsing generic IP options. -var ( - ErrIPv4OptZeroLength = errors.New("zero length IP option") - ErrIPv4OptDuplicate = errors.New("duplicate IP option") - ErrIPv4OptInvalid = errors.New("invalid IP option") - ErrIPv4OptMalformed = errors.New("malformed IP option") - ErrIPv4OptionTruncated = errors.New("truncated IP option") - ErrIPv4OptionAddress = errors.New("bad IP option address") -) +// IPv4OptParameterProblem indicates that a Parameter Problem message +// should be generated, and gives the offset in the current entity +// that should be used in that packet. +type IPv4OptParameterProblem struct { + Pointer uint8 + NeedICMP bool +} // IPv4Option is an interface representing various option types. type IPv4Option interface { @@ -583,8 +580,9 @@ func (i *IPv4OptionIterator) Finalize() IPv4Options { // It returns // - A slice of bytes holding the next option or nil if there is error. // - A boolean which is true if parsing of all the options is complete. -// - An error which is non-nil if an error condition was encountered. -func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) { +// Undefined in the case of error. +// - An error indication which is non-nil if an error condition was found. +func (i *IPv4OptionIterator) Next() (IPv4Option, bool, *IPv4OptParameterProblem) { // The opts slice gets shorter as we process the options. When we have no // bytes left we are done. if len(i.options) == 0 { @@ -606,24 +604,22 @@ func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) { // There are no more single byte options defined. All the rest have a length // field so we need to sanity check it. if len(i.options) == 1 { - return nil, true, ErrIPv4OptMalformed + return nil, false, &IPv4OptParameterProblem{ + Pointer: i.ErrCursor, + NeedICMP: true, + } } optLen := i.options[IPv4OptionLengthOffset] - if optLen == 0 { - i.ErrCursor++ - return nil, true, ErrIPv4OptZeroLength - } + if optLen <= IPv4OptionLengthOffset || optLen > uint8(len(i.options)) { + // The actual error is in the length (2nd byte of the option) but we + // return the start of the option for compatibility with Linux. - if optLen == 1 { - i.ErrCursor++ - return nil, true, ErrIPv4OptMalformed - } - - if optLen > uint8(len(i.options)) { - i.ErrCursor++ - return nil, true, ErrIPv4OptionTruncated + return nil, false, &IPv4OptParameterProblem{ + Pointer: i.ErrCursor, + NeedICMP: true, + } } optionBody := i.options[:optLen] @@ -635,7 +631,10 @@ func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) { case IPv4OptionTimestampType: if optLen < IPv4OptionTimestampHdrLength { i.ErrCursor++ - return nil, true, ErrIPv4OptMalformed + return nil, false, &IPv4OptParameterProblem{ + Pointer: i.ErrCursor, + NeedICMP: true, + } } retval := IPv4OptionTimestamp(optionBody) return &retval, false, nil @@ -643,7 +642,10 @@ func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) { case IPv4OptionRecordRouteType: if optLen < IPv4OptionRecordRouteHdrLength { i.ErrCursor++ - return nil, true, ErrIPv4OptMalformed + return nil, false, &IPv4OptParameterProblem{ + Pointer: i.ErrCursor, + NeedICMP: true, + } } retval := IPv4OptionRecordRoute(optionBody) return &retval, false, nil diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 10072eac1..ae1394ebf 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -35,7 +35,6 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/link/rawfile", "//pkg/tcpip/stack", "@com_github_google_go_cmp//cmp:go_default_library", ], diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 90da22d34..e2985cb84 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -30,7 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -465,67 +464,85 @@ var capLengthTestCases = []struct { config: []int{1, 2, 3}, n: 3, wantUsed: 2, - wantLengths: []int{1, 2, 3}, + wantLengths: []int{1, 2}, }, } -func TestReadVDispatcherCapLength(t *testing.T) { +func TestIovecBuffer(t *testing.T) { for _, c := range capLengthTestCases { - // fd does not matter for this test. - d := readVDispatcher{fd: -1, e: &endpoint{}} - d.views = make([]buffer.View, len(c.config)) - d.iovecs = make([]syscall.Iovec, len(c.config)) - d.allocateViews(c.config) - - used := d.capViews(c.n, c.config) - if used != c.wantUsed { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) - } - lengths := make([]int, len(d.views)) - for i, v := range d.views { - lengths[i] = len(v) - } - if !reflect.DeepEqual(lengths, c.wantLengths) { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) - } - } -} + t.Run(c.comment, func(t *testing.T) { + b := newIovecBuffer(c.config, false /* skipsVnetHdr */) -func TestRecvMMsgDispatcherCapLength(t *testing.T) { - for _, c := range capLengthTestCases { - d := recvMMsgDispatcher{ - fd: -1, // fd does not matter for this test. - e: &endpoint{}, - views: make([][]buffer.View, 1), - iovecs: make([][]syscall.Iovec, 1), - msgHdrs: make([]rawfile.MMsgHdr, 1), - } + // Test initial allocation. + iovecs := b.nextIovecs() + if got, want := len(iovecs), len(c.config); got != want { + t.Fatalf("len(iovecs) = %d, want %d", got, want) + } - for i := range d.views { - d.views[i] = make([]buffer.View, len(c.config)) - } - for i := range d.iovecs { - d.iovecs[i] = make([]syscall.Iovec, len(c.config)) - } - for k, msgHdr := range d.msgHdrs { - msgHdr.Msg.Iov = &d.iovecs[k][0] - msgHdr.Msg.Iovlen = uint64(len(c.config)) - } + // Make a copy as iovecs points to internal slice. We will need this state + // later. + oldIovecs := append([]syscall.Iovec(nil), iovecs...) - d.allocateViews(c.config) + // Test the views that get pulled. + vv := b.pullViews(c.n) + var lengths []int + for _, v := range vv.Views() { + lengths = append(lengths, len(v)) + } + if !reflect.DeepEqual(lengths, c.wantLengths) { + t.Errorf("Pulled view lengths = %v, want %v", lengths, c.wantLengths) + } - used := d.capViews(0, c.n, c.config) - if used != c.wantUsed { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) - } - lengths := make([]int, len(d.views[0])) - for i, v := range d.views[0] { - lengths[i] = len(v) - } - if !reflect.DeepEqual(lengths, c.wantLengths) { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) - } + // Test that new views get reallocated. + for i, newIov := range b.nextIovecs() { + if i < c.wantUsed { + if newIov.Base == oldIovecs[i].Base { + t.Errorf("b.views[%d] should have been reallocated", i) + } + } else { + if newIov.Base != oldIovecs[i].Base { + t.Errorf("b.views[%d] should not have been reallocated", i) + } + } + } + }) + } +} +func TestIovecBufferSkipVnetHdr(t *testing.T) { + for _, test := range []struct { + desc string + readN int + wantLen int + }{ + { + desc: "nothing read", + readN: 0, + wantLen: 0, + }, + { + desc: "smaller than vnet header", + readN: virtioNetHdrSize - 1, + wantLen: 0, + }, + { + desc: "header skipped", + readN: virtioNetHdrSize + 100, + wantLen: 100, + }, + } { + t.Run(test.desc, func(t *testing.T) { + b := newIovecBuffer([]int{10, 20, 50, 50}, true) + // Pretend a read happend. + b.nextIovecs() + vv := b.pullViews(test.readN) + if got, want := vv.Size(), test.wantLen; got != want { + t.Errorf("b.pullView(%d).Size() = %d; want %d", test.readN, got, want) + } + if got, want := len(vv.ToOwnedView()), test.wantLen; got != want { + t.Errorf("b.pullView(%d).ToOwnedView() has length %d; want %d", test.readN, got, want) + } + }) } } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 8c3ca86d6..edab110b5 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -29,92 +29,124 @@ import ( // BufConfig defines the shape of the vectorised view used to read packets from the NIC. var BufConfig = []int{128, 256, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768} -// readVDispatcher uses readv() system call to read inbound packets and -// dispatches them. -type readVDispatcher struct { - // fd is the file descriptor used to send and receive packets. - fd int - - // e is the endpoint this dispatcher is attached to. - e *endpoint - +type iovecBuffer struct { // views are the actual buffers that hold the packet contents. views []buffer.View // iovecs are initialized with base pointers/len of the corresponding - // entries in the views defined above, except when GSO is enabled then - // the first iovec points to a buffer for the vnet header which is - // stripped before the views are passed up the stack for further + // entries in the views defined above, except when GSO is enabled + // (skipsVnetHdr) then the first iovec points to a buffer for the vnet header + // which is stripped before the views are passed up the stack for further // processing. iovecs []syscall.Iovec + + // sizes is an array of buffer sizes for the underlying views. sizes is + // immutable. + sizes []int + + // skipsVnetHdr is true if virtioNetHdr is to skipped. + skipsVnetHdr bool } -func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { - d := &readVDispatcher{fd: fd, e: e} - d.views = make([]buffer.View, len(BufConfig)) - iovLen := len(BufConfig) - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - iovLen++ +func newIovecBuffer(sizes []int, skipsVnetHdr bool) *iovecBuffer { + b := &iovecBuffer{ + views: make([]buffer.View, len(sizes)), + sizes: sizes, + skipsVnetHdr: skipsVnetHdr, } - d.iovecs = make([]syscall.Iovec, iovLen) - return d, nil + niov := len(b.views) + if b.skipsVnetHdr { + niov++ + } + b.iovecs = make([]syscall.Iovec, niov) + return b } -func (d *readVDispatcher) allocateViews(bufConfig []int) { - var vnetHdr [virtioNetHdrSize]byte +func (b *iovecBuffer) nextIovecs() []syscall.Iovec { vnetHdrOff := 0 - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if b.skipsVnetHdr { + var vnetHdr [virtioNetHdrSize]byte // The kernel adds virtioNetHdr before each packet, but // we don't use it, so so we allocate a buffer for it, // add it in iovecs but don't add it in a view. - d.iovecs[0] = syscall.Iovec{ + b.iovecs[0] = syscall.Iovec{ Base: &vnetHdr[0], Len: uint64(virtioNetHdrSize), } vnetHdrOff++ } - for i := 0; i < len(bufConfig); i++ { - if d.views[i] != nil { + for i := range b.views { + if b.views[i] != nil { break } - b := buffer.NewView(bufConfig[i]) - d.views[i] = b - d.iovecs[i+vnetHdrOff] = syscall.Iovec{ - Base: &b[0], - Len: uint64(len(b)), + v := buffer.NewView(b.sizes[i]) + b.views[i] = v + b.iovecs[i+vnetHdrOff] = syscall.Iovec{ + Base: &v[0], + Len: uint64(len(v)), } } + return b.iovecs } -func (d *readVDispatcher) capViews(n int, buffers []int) int { +func (b *iovecBuffer) pullViews(n int) buffer.VectorisedView { + var views []buffer.View c := 0 - for i, s := range buffers { - c += s + if b.skipsVnetHdr { + c += virtioNetHdrSize if c >= n { - d.views[i].CapLength(s - (c - n)) - return i + 1 + // Nothing in the packet. + return buffer.NewVectorisedView(0, nil) + } + } + for i, v := range b.views { + c += len(v) + if c >= n { + b.views[i].CapLength(len(v) - (c - n)) + views = append([]buffer.View(nil), b.views[:i+1]...) + break } } - return len(buffers) + // Remove the first len(views) used views from the state. + for i := range views { + b.views[i] = nil + } + if b.skipsVnetHdr { + // Exclude the size of the vnet header. + n -= virtioNetHdrSize + } + return buffer.NewVectorisedView(n, views) +} + +// readVDispatcher uses readv() system call to read inbound packets and +// dispatches them. +type readVDispatcher struct { + // fd is the file descriptor used to send and receive packets. + fd int + + // e is the endpoint this dispatcher is attached to. + e *endpoint + + // buf is the iovec buffer that contains the packet contents. + buf *iovecBuffer +} + +func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { + d := &readVDispatcher{fd: fd, e: e} + skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 + d.buf = newIovecBuffer(BufConfig, skipsVnetHdr) + return d, nil } // dispatch reads one packet from the file descriptor and dispatches it. func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { - d.allocateViews(BufConfig) - - n, err := rawfile.BlockingReadv(d.fd, d.iovecs) + n, err := rawfile.BlockingReadv(d.fd, d.buf.nextIovecs()) if n == 0 || err != nil { return false, err } - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - // Skip virtioNetHdr which is added before each packet, it - // isn't used and it isn't in a view. - n -= virtioNetHdrSize - } - used := d.capViews(n, BufConfig) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), + Data: d.buf.pullViews(n), }) var ( @@ -133,7 +165,12 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } else { // We don't get any indication of what the packet is, so try to guess // if it's an IPv4 or IPv6 packet. - switch header.IPVersion(d.views[0]) { + // IP version information is at the first octet, so pulling up 1 byte. + h, ok := pkt.Data.PullUp(1) + if !ok { + return true, nil + } + switch header.IPVersion(h) { case header.IPv4Version: p = header.IPv4ProtocolNumber case header.IPv6Version: @@ -145,11 +182,6 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) - // Prepare e.views for another packet: release used views. - for i := 0; i < used; i++ { - d.views[i] = nil - } - return true, nil } @@ -162,15 +194,8 @@ type recvMMsgDispatcher struct { // e is the endpoint this dispatcher is attached to. e *endpoint - // views is an array of array of buffers that contain packet contents. - views [][]buffer.View - - // iovecs is an array of array of iovec records where each iovec base - // pointer and length are initialzed to the corresponding view above, - // except when GSO is enabled then the first iovec in each array of - // iovecs points to a buffer for the vnet header which is stripped - // before the views are passed up the stack for further processing. - iovecs [][]syscall.Iovec + // bufs is an array of iovec buffers that contain packet contents. + bufs []*iovecBuffer // msgHdrs is an array of MMsgHdr objects where each MMsghdr is used to // reference an array of iovecs in the iovecs field defined above. This @@ -187,74 +212,32 @@ const ( func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) { d := &recvMMsgDispatcher{ - fd: fd, - e: e, - } - d.views = make([][]buffer.View, MaxMsgsPerRecv) - for i := range d.views { - d.views[i] = make([]buffer.View, len(BufConfig)) - } - d.iovecs = make([][]syscall.Iovec, MaxMsgsPerRecv) - iovLen := len(BufConfig) - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - // virtioNetHdr is prepended before each packet. - iovLen++ + fd: fd, + e: e, + bufs: make([]*iovecBuffer, MaxMsgsPerRecv), + msgHdrs: make([]rawfile.MMsgHdr, MaxMsgsPerRecv), } - for i := range d.iovecs { - d.iovecs[i] = make([]syscall.Iovec, iovLen) - } - d.msgHdrs = make([]rawfile.MMsgHdr, MaxMsgsPerRecv) - for i := range d.msgHdrs { - d.msgHdrs[i].Msg.Iov = &d.iovecs[i][0] - d.msgHdrs[i].Msg.Iovlen = uint64(iovLen) + skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 + for i := range d.bufs { + d.bufs[i] = newIovecBuffer(BufConfig, skipsVnetHdr) } return d, nil } -func (d *recvMMsgDispatcher) capViews(k, n int, buffers []int) int { - c := 0 - for i, s := range buffers { - c += s - if c >= n { - d.views[k][i].CapLength(s - (c - n)) - return i + 1 - } - } - return len(buffers) -} - -func (d *recvMMsgDispatcher) allocateViews(bufConfig []int) { - for k := 0; k < len(d.views); k++ { - var vnetHdr [virtioNetHdrSize]byte - vnetHdrOff := 0 - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - // The kernel adds virtioNetHdr before each packet, but - // we don't use it, so so we allocate a buffer for it, - // add it in iovecs but don't add it in a view. - d.iovecs[k][0] = syscall.Iovec{ - Base: &vnetHdr[0], - Len: uint64(virtioNetHdrSize), - } - vnetHdrOff++ - } - for i := 0; i < len(bufConfig); i++ { - if d.views[k][i] != nil { - break - } - b := buffer.NewView(bufConfig[i]) - d.views[k][i] = b - d.iovecs[k][i+vnetHdrOff] = syscall.Iovec{ - Base: &b[0], - Len: uint64(len(b)), - } - } - } -} - // recvMMsgDispatch reads more than one packet at a time from the file // descriptor and dispatches it. func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { - d.allocateViews(BufConfig) + // Fill message headers. + for k := range d.msgHdrs { + if d.msgHdrs[k].Msg.Iovlen > 0 { + break + } + iovecs := d.bufs[k].nextIovecs() + iovLen := len(iovecs) + d.msgHdrs[k].Len = 0 + d.msgHdrs[k].Msg.Iov = &iovecs[0] + d.msgHdrs[k].Msg.Iovlen = uint64(iovLen) + } nMsgs, err := rawfile.BlockingRecvMMsg(d.fd, d.msgHdrs) if err != nil { @@ -263,15 +246,14 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { // Process each of received packets. for k := 0; k < nMsgs; k++ { n := int(d.msgHdrs[k].Len) - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - n -= virtioNetHdrSize - } - used := d.capViews(k, int(n), BufConfig) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), + Data: d.bufs[k].pullViews(n), }) + // Mark that this iovec has been processed. + d.msgHdrs[k].Msg.Iovlen = 0 + var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress @@ -288,26 +270,24 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } else { // We don't get any indication of what the packet is, so try to guess // if it's an IPv4 or IPv6 packet. - switch header.IPVersion(d.views[k][0]) { + // IP version information is at the first octet, so pulling up 1 byte. + h, ok := pkt.Data.PullUp(1) + if !ok { + // Skip this packet. + continue + } + switch header.IPVersion(h) { case header.IPv4Version: p = header.IPv4ProtocolNumber case header.IPv6Version: p = header.IPv6ProtocolNumber default: - return true, nil + // Skip this packet. + continue } } d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) - - // Prepare e.views for another packet: release used views. - for i := 0; i < used; i++ { - d.views[k][i] = nil - } - } - - for k := 0; k < nMsgs; k++ { - d.msgHdrs[k].Len = 0 } return true, nil diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index d6e83a414..36aa9055c 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -45,12 +45,7 @@ type Endpoint struct { linkAddr tcpip.LinkAddress } -// WritePacket implements stack.LinkEndpoint. -func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - if !e.linked.IsAttached() { - return nil - } - +func (e *Endpoint) deliverPackets(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkts stack.PacketBufferList) { // Note that the local address from the perspective of this endpoint is the // remote address from the perspective of the other end of the pipe // (e.linked). Similarly, the remote address from the perspective of this @@ -70,16 +65,33 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.Netw // // TODO(gvisor.dev/issue/5289): don't use a new goroutine once we support send // and receive queues. - go e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), - })) + go func() { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + })) + } + }() +} + +// WritePacket implements stack.LinkEndpoint. +func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if e.linked.IsAttached() { + var pkts stack.PacketBufferList + pkts.PushBack(pkt) + e.deliverPackets(r, proto, pkts) + } return nil } // WritePackets implements stack.LinkEndpoint. -func (*Endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - panic("not implemented") +func (e *Endpoint) WritePackets(r stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + if e.linked.IsAttached() { + e.deliverPackets(r, proto, pkts) + } + + return pkts.Len(), nil } // Attach implements stack.LinkEndpoint. diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index 87035b034..03efba606 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -165,12 +165,15 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip } // WritePackets implements stack.LinkEndpoint.WritePackets. +// +// Being a batch API, each packet in pkts should have the following +// fields populated: +// - pkt.EgressRoute +// - pkt.GSOOptions +// - pkt.NetworkProtocolNumber func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { enqueued := 0 for pkt := pkts.Front(); pkt != nil; { - pkt.EgressRoute = r - pkt.GSOOptions = gso - pkt.NetworkProtocolNumber = protocol d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] nxt := pkt.Next() if !d.q.enqueue(pkt) { diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index 8a6bcfc2c..c7ab876bf 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -4,9 +4,13 @@ package(licenses = ["notice"]) go_library( name = "arp", - srcs = ["arp.go"], + srcs = [ + "arp.go", + "stats.go", + ], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", @@ -33,3 +37,15 @@ go_test( "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) + +go_test( + name = "stats_test", + size = "small", + srcs = ["stats_test.go"], + library = ":arp", + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/network/testutil", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 9255a4f6a..6bc8c5c02 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -19,8 +19,10 @@ package arp import ( "fmt" + "reflect" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -50,6 +52,7 @@ type endpoint struct { nic stack.NetworkInterface linkAddrCache stack.LinkAddressCache nud stack.NUDHandler + stats sharedStats } func (e *endpoint) Enable() *tcpip.Error { @@ -98,7 +101,9 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.ARPSize } -func (*endpoint) Close() {} +func (e *endpoint) Close() { + e.protocol.forgetEndpoint(e.nic.ID()) +} func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported @@ -119,27 +124,27 @@ func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) *t } func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { - stats := e.protocol.stack.Stats().ARP - stats.PacketsReceived.Increment() + stats := e.stats.arp + stats.packetsReceived.Increment() if !e.isEnabled() { - stats.DisabledPacketsReceived.Increment() + stats.disabledPacketsReceived.Increment() return } h := header.ARP(pkt.NetworkHeader().View()) if !h.IsValid() { - stats.MalformedPacketsReceived.Increment() + stats.malformedPacketsReceived.Increment() return } switch h.Op() { case header.ARPRequest: - stats.RequestsReceived.Increment() + stats.requestsReceived.Increment() localAddr := tcpip.Address(h.ProtocolAddressTarget()) if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { - stats.RequestsReceivedUnknownTargetAddress.Increment() + stats.requestsReceivedUnknownTargetAddress.Increment() return // we have no useful answer, ignore the request } @@ -180,13 +185,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Send the packet to the (new) target hardware address on the same // hardware on which the request was received. if err := e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt); err != nil { - stats.OutgoingRepliesDropped.Increment() + stats.outgoingRepliesDropped.Increment() } else { - stats.OutgoingRepliesSent.Increment() + stats.outgoingRepliesSent.Increment() } case header.ARPReply: - stats.RepliesReceived.Increment() + stats.repliesReceived.Increment() addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) @@ -212,21 +217,23 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Stats implements stack.NetworkEndpoint. func (e *endpoint) Stats() stack.NetworkEndpointStats { - // TODO(gvisor.dev/issues/4963): Record statistics for ARP. - return &Stats{} + return &e.stats.localStats } -var _ stack.NetworkEndpointStats = (*Stats)(nil) - -// Stats holds ARP statistics. -type Stats struct{} - -// IsNetworkEndpointStats implements stack.NetworkEndpointStats. -func (*Stats) IsNetworkEndpointStats() {} +var _ stack.NetworkProtocol = (*protocol)(nil) +var _ stack.LinkAddressResolver = (*protocol)(nil) // protocol implements stack.NetworkProtocol and stack.LinkAddressResolver. type protocol struct { stack *stack.Stack + + mu struct { + sync.RWMutex + + // eps is keyed by NICID to allow protocol methods to retrieve the correct + // endpoint depending on the NIC. + eps map[tcpip.NICID]*endpoint + } } func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } @@ -244,9 +251,25 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L linkAddrCache: linkAddrCache, nud: nud, } + + tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem()) + + stackStats := p.stack.Stats() + e.stats.arp.init(&e.stats.localStats.ARP, &stackStats.ARP) + + p.mu.Lock() + p.mu.eps[nic.ID()] = e + p.mu.Unlock() + return e } +func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.mu.eps, nicID) +} + // LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv4ProtocolNumber @@ -254,28 +277,35 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { - stats := p.stack.Stats().ARP + nicID := nic.ID() + + p.mu.Lock() + netEP, ok := p.mu.eps[nicID] + p.mu.Unlock() + if !ok { + return tcpip.ErrNotConnected + } + + stats := netEP.stats.arp if len(remoteLinkAddr) == 0 { remoteLinkAddr = header.EthernetBroadcastAddress } - nicID := nic.ID() if len(localAddr) == 0 { - addr, err := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) - if err != nil { - stats.OutgoingRequestInterfaceHasNoLocalAddressErrors.Increment() - return err + addr, ok := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) + if !ok { + return tcpip.ErrUnknownNICID } if len(addr.Address) == 0 { - stats.OutgoingRequestNetworkUnreachableErrors.Increment() + stats.outgoingRequestInterfaceHasNoLocalAddressErrors.Increment() return tcpip.ErrNetworkUnreachable } localAddr = addr.Address } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { - stats.OutgoingRequestBadLocalAddressErrors.Increment() + stats.outgoingRequestBadLocalAddressErrors.Increment() return tcpip.ErrBadLocalAddress } @@ -296,10 +326,10 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) } if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { - stats.OutgoingRequestsDropped.Increment() + stats.outgoingRequestsDropped.Increment() return err } - stats.OutgoingRequestsSent.Increment() + stats.outgoingRequestsSent.Increment() return nil } @@ -337,5 +367,11 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu // NewProtocol returns an ARP network protocol. func NewProtocol(s *stack.Stack) stack.NetworkProtocol { - return &protocol{stack: s} + return &protocol{ + stack: s, + mu: struct { + sync.RWMutex + eps map[tcpip.NICID]*endpoint + }{eps: make(map[tcpip.NICID]*endpoint)}, + } } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 8d6ee37fa..6b23c0079 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -585,104 +585,122 @@ func TestLinkAddressRequest(t *testing.T) { testAddr := tcpip.Address([]byte{1, 2, 3, 4}) tests := []struct { - name string - nicAddr tcpip.Address - localAddr tcpip.Address - remoteLinkAddr tcpip.LinkAddress - - linkErr *tcpip.Error - expectedErr *tcpip.Error - expectedLocalAddr tcpip.Address - expectedRemoteLinkAddr tcpip.LinkAddress - expectedRequestsSent uint64 - expectedRequestBadLocalAddressErrors uint64 - expectedRequestNetworkUnreachableErrors uint64 - expectedRequestDroppedErrors uint64 + name string + nicAddr tcpip.Address + localAddr tcpip.Address + remoteLinkAddr tcpip.LinkAddress + linkErr *tcpip.Error + expectedErr *tcpip.Error + expectedLocalAddr tcpip.Address + expectedRemoteLinkAddr tcpip.LinkAddress + expectedRequestsSent uint64 + expectedRequestBadLocalAddressErrors uint64 + expectedRequestInterfaceHasNoLocalAddressErrors uint64 + expectedRequestDroppedErrors uint64 }{ { - name: "Unicast", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: remoteLinkAddr, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestNetworkUnreachableErrors: 0, + name: "Unicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 0, }, { - name: "Multicast", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: "", - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: header.EthernetBroadcastAddress, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestNetworkUnreachableErrors: 0, + name: "Multicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 0, }, { - name: "Unicast with unspecified source", - nicAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: remoteLinkAddr, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestNetworkUnreachableErrors: 0, + name: "Unicast with unspecified source", + nicAddr: stackAddr, + localAddr: "", + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 0, }, { - name: "Multicast with unspecified source", - nicAddr: stackAddr, - remoteLinkAddr: "", - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: header.EthernetBroadcastAddress, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestNetworkUnreachableErrors: 0, + name: "Multicast with unspecified source", + nicAddr: stackAddr, + localAddr: "", + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 0, }, { - name: "Unicast with unassigned address", - localAddr: testAddr, - remoteLinkAddr: remoteLinkAddr, - expectedErr: tcpip.ErrBadLocalAddress, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 1, - expectedRequestNetworkUnreachableErrors: 0, + name: "Unicast with unassigned address", + nicAddr: stackAddr, + localAddr: testAddr, + remoteLinkAddr: remoteLinkAddr, + expectedErr: tcpip.ErrBadLocalAddress, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 1, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 0, }, { - name: "Multicast with unassigned address", - localAddr: testAddr, - remoteLinkAddr: "", - expectedErr: tcpip.ErrBadLocalAddress, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 1, - expectedRequestNetworkUnreachableErrors: 0, + name: "Multicast with unassigned address", + nicAddr: stackAddr, + localAddr: testAddr, + remoteLinkAddr: "", + expectedErr: tcpip.ErrBadLocalAddress, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 1, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 0, }, { - name: "Unicast with no local address available", - remoteLinkAddr: remoteLinkAddr, - expectedErr: tcpip.ErrNetworkUnreachable, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestNetworkUnreachableErrors: 1, + name: "Unicast with no local address available", + nicAddr: "", + localAddr: "", + remoteLinkAddr: remoteLinkAddr, + expectedErr: tcpip.ErrNetworkUnreachable, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 1, + expectedRequestDroppedErrors: 0, }, { - name: "Multicast with no local address available", - remoteLinkAddr: "", - expectedErr: tcpip.ErrNetworkUnreachable, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestNetworkUnreachableErrors: 1, + name: "Multicast with no local address available", + nicAddr: "", + localAddr: "", + remoteLinkAddr: "", + expectedErr: tcpip.ErrNetworkUnreachable, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 1, + expectedRequestDroppedErrors: 0, }, { - name: "Link error", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - linkErr: tcpip.ErrInvalidEndpointState, - expectedErr: tcpip.ErrInvalidEndpointState, - expectedRequestDroppedErrors: 1, + name: "Link error", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + linkErr: tcpip.ErrInvalidEndpointState, + expectedErr: tcpip.ErrInvalidEndpointState, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 1, }, } @@ -721,12 +739,12 @@ func TestLinkAddressRequest(t *testing.T) { if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent { t.Errorf("got s.Stats().ARP.OutgoingRequestsSent.Value() = %d, want = %d", got, test.expectedRequestsSent) } + if got := s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value(); got != test.expectedRequestInterfaceHasNoLocalAddressErrors { + t.Errorf("got s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestInterfaceHasNoLocalAddressErrors) + } if got := s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value(); got != test.expectedRequestBadLocalAddressErrors { t.Errorf("got s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestBadLocalAddressErrors) } - if got := s.Stats().ARP.OutgoingRequestNetworkUnreachableErrors.Value(); got != test.expectedRequestNetworkUnreachableErrors { - t.Errorf("got s.Stats().ARP.OutgoingRequestNetworkUnreachableErrors.Value() = %d, want = %d", got, test.expectedRequestNetworkUnreachableErrors) - } if got := s.Stats().ARP.OutgoingRequestsDropped.Value(); got != test.expectedRequestDroppedErrors { t.Errorf("got s.Stats().ARP.OutgoingRequestsDropped.Value() = %d, want = %d", got, test.expectedRequestDroppedErrors) } @@ -774,11 +792,7 @@ func TestLinkAddressRequestWithoutNIC(t *testing.T) { t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") } - if err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID}); err != tcpip.ErrUnknownNICID { - t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, tcpip.ErrUnknownNICID) - } - - if got := s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value(); got != 1 { - t.Errorf("got s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value() = %d, want = 1", got) + if err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID}); err != tcpip.ErrNotConnected { + t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, tcpip.ErrNotConnected) } } diff --git a/pkg/tcpip/network/arp/stats.go b/pkg/tcpip/network/arp/stats.go new file mode 100644 index 000000000..6d7194c6c --- /dev/null +++ b/pkg/tcpip/network/arp/stats.go @@ -0,0 +1,70 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arp + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.NetworkEndpointStats = (*Stats)(nil) + +// Stats holds statistics related to ARP. +type Stats struct { + // ARP holds ARP statistics. + ARP tcpip.ARPStats +} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (*Stats) IsNetworkEndpointStats() {} + +type sharedStats struct { + localStats Stats + arp multiCounterARPStats +} + +// LINT.IfChange(multiCounterARPStats) + +type multiCounterARPStats struct { + packetsReceived tcpip.MultiCounterStat + disabledPacketsReceived tcpip.MultiCounterStat + malformedPacketsReceived tcpip.MultiCounterStat + requestsReceived tcpip.MultiCounterStat + requestsReceivedUnknownTargetAddress tcpip.MultiCounterStat + outgoingRequestInterfaceHasNoLocalAddressErrors tcpip.MultiCounterStat + outgoingRequestBadLocalAddressErrors tcpip.MultiCounterStat + outgoingRequestsDropped tcpip.MultiCounterStat + outgoingRequestsSent tcpip.MultiCounterStat + repliesReceived tcpip.MultiCounterStat + outgoingRepliesDropped tcpip.MultiCounterStat + outgoingRepliesSent tcpip.MultiCounterStat +} + +func (m *multiCounterARPStats) init(a, b *tcpip.ARPStats) { + m.packetsReceived.Init(a.PacketsReceived, b.PacketsReceived) + m.disabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived) + m.malformedPacketsReceived.Init(a.MalformedPacketsReceived, b.MalformedPacketsReceived) + m.requestsReceived.Init(a.RequestsReceived, b.RequestsReceived) + m.requestsReceivedUnknownTargetAddress.Init(a.RequestsReceivedUnknownTargetAddress, b.RequestsReceivedUnknownTargetAddress) + m.outgoingRequestInterfaceHasNoLocalAddressErrors.Init(a.OutgoingRequestInterfaceHasNoLocalAddressErrors, b.OutgoingRequestInterfaceHasNoLocalAddressErrors) + m.outgoingRequestBadLocalAddressErrors.Init(a.OutgoingRequestBadLocalAddressErrors, b.OutgoingRequestBadLocalAddressErrors) + m.outgoingRequestsDropped.Init(a.OutgoingRequestsDropped, b.OutgoingRequestsDropped) + m.outgoingRequestsSent.Init(a.OutgoingRequestsSent, b.OutgoingRequestsSent) + m.repliesReceived.Init(a.RepliesReceived, b.RepliesReceived) + m.outgoingRepliesDropped.Init(a.OutgoingRepliesDropped, b.OutgoingRepliesDropped) + m.outgoingRepliesSent.Init(a.OutgoingRepliesSent, b.OutgoingRepliesSent) +} + +// LINT.ThenChange(../../tcpip.go:ARPStats) diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go new file mode 100644 index 000000000..036fdf739 --- /dev/null +++ b/pkg/tcpip/network/arp/stats_test.go @@ -0,0 +1,93 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arp + +import ( + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct { + stack.NetworkInterface + nicID tcpip.NICID +} + +func (t *testInterface) ID() tcpip.NICID { + return t.nicID +} + +func knownNICIDs(proto *protocol) []tcpip.NICID { + var nicIDs []tcpip.NICID + + for k := range proto.mu.eps { + nicIDs = append(nicIDs, k) + } + + return nicIDs +} + +func TestClearEndpointFromProtocolOnClose(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + nic := testInterface{nicID: 1} + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + var nicIDs []tcpip.NICID + + proto.mu.Lock() + foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] + nicIDs = knownNICIDs(proto) + proto.mu.Unlock() + + if !hasEndpointBeforeClose { + t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs) + } + if foundEP != ep { + t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) + } + + ep.Close() + + proto.mu.Lock() + _, hasEndpointAfterClose := proto.mu.eps[nic.ID()] + nicIDs = knownNICIDs(proto) + proto.mu.Unlock() + if hasEndpointAfterClose { + t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) + } +} + +func TestMultiCounterStatsInitialization(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + var nic testInterface + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + // At this point, the Stack's stats and the NetworkEndpoint's stats are + // expected to be bound by a MultiCounterStat. + refStack := s.Stats() + refEP := ep.stats.localStats + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.arp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ARP).Elem(), reflect.ValueOf(&refStack.ARP).Elem()}); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 3005973d7..2a6ec19dc 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -235,14 +235,14 @@ func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) } -func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpoint) { +func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *channel.Endpoint) { t.Helper() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) - e := channel.New(0, 1280, "") + e := channel.New(1, mtu, "") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -263,7 +263,7 @@ func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpo func buildDummyStack(t *testing.T) *stack.Stack { t.Helper() - s, _ := buildDummyStackWithLinkEndpoint(t) + s, _ := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU) return s } @@ -416,7 +416,7 @@ func TestSourceAddressValidation(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s, e := buildDummyStackWithLinkEndpoint(t) + s, e := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU) test.rxICMP(e, test.srcAddress) var wantValid uint64 @@ -1490,7 +1490,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, }) - e := channel.New(1, 1280, "") + e := channel.New(1, header.IPv6MinimumMTU, "") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } @@ -1526,3 +1526,246 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { }) } } + +// Test that the included data in an ICMP error packet conforms to the +// requirements of RFC 972, RFC 4443 section 2.4 and RFC 1812 Section 4.3.2.3 +func TestICMPInclusionSize(t *testing.T) { + const ( + replyHeaderLength4 = header.IPv4MinimumSize + header.IPv4MinimumSize + header.ICMPv4MinimumSize + replyHeaderLength6 = header.IPv6MinimumSize + header.IPv6MinimumSize + header.ICMPv6MinimumSize + targetSize4 = header.IPv4MinimumProcessableDatagramSize + targetSize6 = header.IPv6MinimumMTU + // A protocol number that will cause an error response. + reservedProtocol = 254 + ) + + // IPv4 function to create a IP packet and send it to the stack. + // The packet should generate an error response. We can do that by using an + // unknown transport protocol (254). + rxIPv4Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View { + totalLen := header.IPv4MinimumSize + len(payload) + hdr := buffer.NewPrependable(header.IPv4MinimumSize) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + Protocol: reservedProtocol, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv4Addr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + vv := hdr.View().ToVectorisedView() + vv.AppendView(buffer.View(payload)) + // Take a copy before InjectInbound takes ownership of vv + // as vv may be changed during the call. + v := vv.ToView() + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) + return v + } + + // IPv6 function to create a packet and send it to the stack. + // The packet should be errant in a way that causes the stack to send an + // ICMP error response and have enough data to allow the testing of the + // inclusion of the errant packet. Use `unknown next header' to generate + // the error. + rxIPv6Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View { + hdr := buffer.NewPrependable(header.IPv6MinimumSize) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(payload)), + TransportProtocol: reservedProtocol, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv6Addr, + }) + vv := hdr.View().ToVectorisedView() + vv.AppendView(buffer.View(payload)) + // Take a copy before InjectInbound takes ownership of vv + // as vv may be changed during the call. + v := vv.ToView() + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) + return v + } + + v4Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) { + // We already know the entire packet is the right size so we can use its + // length to calculate the right payload size to check. + expectedPayloadLength := pkt.Size() - header.IPv4MinimumSize - header.ICMPv4MinimumSize + checker.IPv4(t, stack.PayloadSince(pkt.NetworkHeader()), + checker.SrcAddr(localIPv4Addr), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+expectedPayloadLength)), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4ProtoUnreachable), + checker.ICMPv4Payload(payload[:expectedPayloadLength]), + ), + ) + } + + v6Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) { + // We already know the entire packet is the right size so we can use its + // length to calculate the right payload size to check. + expectedPayloadLength := pkt.Size() - header.IPv6MinimumSize - header.ICMPv6MinimumSize + checker.IPv6(t, stack.PayloadSince(pkt.NetworkHeader()), + checker.SrcAddr(localIPv6Addr), + checker.DstAddr(remoteIPv6Addr), + checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectedPayloadLength)), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6ParamProblem), + checker.ICMPv6Code(header.ICMPv6UnknownHeader), + checker.ICMPv6Payload(payload[:expectedPayloadLength]), + ), + ) + } + tests := []struct { + name string + srcAddress tcpip.Address + injector func(*channel.Endpoint, tcpip.Address, []byte) buffer.View + checker func(*testing.T, *stack.PacketBuffer, buffer.View) + payloadLength int // Not including IP header. + linkMTU uint32 // Largest IP packet that the link can send as payload. + replyLength int // Total size of IP/ICMP packet expected back. + }{ + { + name: "IPv4 exact match", + srcAddress: remoteIPv4Addr, + injector: rxIPv4Bad, + checker: v4Checker, + payloadLength: targetSize4 - replyHeaderLength4, + linkMTU: targetSize4, + replyLength: targetSize4, + }, + { + name: "IPv4 larger MTU", + srcAddress: remoteIPv4Addr, + injector: rxIPv4Bad, + checker: v4Checker, + payloadLength: targetSize4, + linkMTU: targetSize4 + 1000, + replyLength: targetSize4, + }, + { + name: "IPv4 smaller MTU", + srcAddress: remoteIPv4Addr, + injector: rxIPv4Bad, + checker: v4Checker, + payloadLength: targetSize4, + linkMTU: targetSize4 - 50, + replyLength: targetSize4 - 50, + }, + { + name: "IPv4 payload exceeds", + srcAddress: remoteIPv4Addr, + injector: rxIPv4Bad, + checker: v4Checker, + payloadLength: targetSize4 + 10, + linkMTU: targetSize4, + replyLength: targetSize4, + }, + { + name: "IPv4 1 byte less", + srcAddress: remoteIPv4Addr, + injector: rxIPv4Bad, + checker: v4Checker, + payloadLength: targetSize4 - replyHeaderLength4 - 1, + linkMTU: targetSize4, + replyLength: targetSize4 - 1, + }, + { + name: "IPv4 No payload", + srcAddress: remoteIPv4Addr, + injector: rxIPv4Bad, + checker: v4Checker, + payloadLength: 0, + linkMTU: targetSize4, + replyLength: replyHeaderLength4, + }, + { + name: "IPv6 exact match", + srcAddress: remoteIPv6Addr, + injector: rxIPv6Bad, + checker: v6Checker, + payloadLength: targetSize6 - replyHeaderLength6, + linkMTU: targetSize6, + replyLength: targetSize6, + }, + { + name: "IPv6 larger MTU", + srcAddress: remoteIPv6Addr, + injector: rxIPv6Bad, + checker: v6Checker, + payloadLength: targetSize6, + linkMTU: targetSize6 + 400, + replyLength: targetSize6, + }, + // NB. No "smaller MTU" test here as less than 1280 is not permitted + // in IPv6. + { + name: "IPv6 payload exceeds", + srcAddress: remoteIPv6Addr, + injector: rxIPv6Bad, + checker: v6Checker, + payloadLength: targetSize6, + linkMTU: targetSize6, + replyLength: targetSize6, + }, + { + name: "IPv6 1 byte less", + srcAddress: remoteIPv6Addr, + injector: rxIPv6Bad, + checker: v6Checker, + payloadLength: targetSize6 - replyHeaderLength6 - 1, + linkMTU: targetSize6, + replyLength: targetSize6 - 1, + }, + { + name: "IPv6 no payload", + srcAddress: remoteIPv6Addr, + injector: rxIPv6Bad, + checker: v6Checker, + payloadLength: 0, + linkMTU: targetSize6, + replyLength: replyHeaderLength6, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, e := buildDummyStackWithLinkEndpoint(t, test.linkMTU) + // Allocate and initialize the payload view. + payload := buffer.NewView(test.payloadLength) + for i := 0; i < len(payload); i++ { + payload[i] = uint8(i) + } + // Default routes for IPv4&6 so ICMP can find a route to the remote + // node when attempting to send the ICMP error Reply. + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + v := test.injector(e, test.srcAddress, payload) + pkt, ok := e.Read() + if !ok { + t.Fatal("expected a packet to be written") + } + if got, want := pkt.Pkt.Size(), test.replyLength; got != want { + t.Fatalf("got %d bytes of icmp error packet, want %d", got, want) + } + test.checker(t, pkt.Pkt, v) + }) + } +} diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 3f60de749..6bb97c46a 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,7 +15,6 @@ package ipv4 import ( - "errors" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -105,17 +104,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { } else { op = &optionUsageReceive{} } - aux, tmp, err := e.processIPOptions(pkt, opts, op) - if err != nil { - switch { - case - errors.Is(err, header.ErrIPv4OptDuplicate), - errors.Is(err, errIPv4RecordRouteOptInvalidLength), - errors.Is(err, errIPv4RecordRouteOptInvalidPointer), - errors.Is(err, errIPv4TimestampOptInvalidLength), - errors.Is(err, errIPv4TimestampOptInvalidPointer), - errors.Is(err, errIPv4TimestampOptOverflow): - _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) + tmp, optProblem := e.processIPOptions(pkt, opts, op) + if optProblem != nil { + if optProblem.NeedICMP { + _ = e.protocol.returnError(&icmpReasonParamProblem{ + pointer: optProblem.Pointer, + }, pkt) e.protocol.stack.Stats().MalformedRcvdPackets.Increment() e.stats.ip.MalformedPacketsReceived.Increment() } @@ -442,13 +436,13 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi // systems implement the RFC 1812 definition and not the original // requirement. We treat 8 bytes as the minimum but will try send more. mtu := int(route.MTU()) - if mtu > header.IPv4MinimumProcessableDatagramSize { - mtu = header.IPv4MinimumProcessableDatagramSize + const maxIPData = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize + if mtu > maxIPData { + mtu = maxIPData } - headerLen := int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize - available := int(mtu) - headerLen + available := mtu - header.ICMPv4MinimumSize - if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize { + if available < len(origIPHdr)+header.ICMPv4MinimumErrorPayloadSize { return nil } @@ -471,7 +465,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi payload.CapLength(payloadLen) icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: headerLen, + ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize, Data: payload, }) diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index 9515fde45..4550aacd6 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -157,14 +157,13 @@ func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { } h := header.IGMP(headerView) - // Temporarily reset the checksum field to 0 in order to calculate the proper - // checksum. - wantChecksum := h.Checksum() - h.SetChecksum(0) - gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) - h.SetChecksum(wantChecksum) - - if gotChecksum != wantChecksum { + // As per RFC 1071 section 1.3, + // + // To check a checksum, the 1's complement sum is computed over the + // same set of octets, including the checksum field. If the result + // is all 1 bits (-0 in 1's complement arithmetic), the check + // succeeds. + if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xFFFF { received.checksumErrors.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 7f03696ae..a05275a5b 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package ipv4 import ( - "errors" "fmt" "math" "reflect" @@ -322,8 +321,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -437,10 +436,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } } - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -612,7 +611,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. if !e.nic.IsLoopback() { - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -699,7 +699,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.ip.IPTablesInputDropped.Increment() return @@ -780,17 +781,11 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // TODO(gvisor.dev/issue/4586): // When we add forwarding support we should use the verified options // rather than just throwing them away. - aux, _, err := e.processIPOptions(pkt, opts, &optionUsageReceive{}) - if err != nil { - switch { - case - errors.Is(err, header.ErrIPv4OptDuplicate), - errors.Is(err, errIPv4RecordRouteOptInvalidPointer), - errors.Is(err, errIPv4RecordRouteOptInvalidLength), - errors.Is(err, errIPv4TimestampOptInvalidLength), - errors.Is(err, errIPv4TimestampOptInvalidPointer), - errors.Is(err, errIPv4TimestampOptOverflow): - _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) + if _, optProblem := e.processIPOptions(pkt, opts, &optionUsageReceive{}); optProblem != nil { + if optProblem.NeedICMP { + _ = e.protocol.returnError(&icmpReasonParamProblem{ + pointer: optProblem.Pointer, + }, pkt) e.protocol.stack.Stats().MalformedRcvdPackets.Increment() stats.ip.MalformedPacketsReceived.Increment() } @@ -1233,16 +1228,9 @@ func (*optionUsageEcho) actions() optionActions { } } -var ( - errIPv4TimestampOptInvalidLength = errors.New("invalid Timestamp length") - errIPv4TimestampOptInvalidPointer = errors.New("invalid Timestamp pointer") - errIPv4TimestampOptOverflow = errors.New("overflow in Timestamp") - errIPv4TimestampOptInvalidFlags = errors.New("invalid Timestamp flags") -) - // handleTimestamp does any required processing on a Timestamp option // in place. -func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) (uint8, error) { +func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) *header.IPv4OptParameterProblem { flags := tsOpt.Flags() var entrySize uint8 switch flags { @@ -1253,7 +1241,10 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres header.IPv4OptionTimestampWithPredefinedIPFlag: entrySize = header.IPv4OptionTimestampWithAddrSize default: - return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptInvalidFlags + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptTSOFLWAndFLGOffset, + NeedICMP: true, + } } pointer := tsOpt.Pointer() @@ -1261,7 +1252,10 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres // Since the pointer is 1 based, and the header is 4 bytes long the // pointer must point beyond the header therefore 4 or less is bad. if pointer <= header.IPv4OptionTimestampHdrLength { - return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptTSPointerOffset, + NeedICMP: true, + } } // To simplify processing below, base further work on the array of timestamps // beyond the header, rather than on the whole option. Also to aid @@ -1295,14 +1289,17 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres // timestamp, but the overflow count is incremented by one. if flags == header.IPv4OptionTimestampWithPredefinedIPFlag { // By definition we have nothing to do. - return 0, nil + return nil } if tsOpt.IncOverflow() != 0 { - return 0, nil + return nil } // The overflow count is also full. - return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptOverflow + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptTSOFLWAndFLGOffset, + NeedICMP: true, + } } if nextSlot+entrySize > dataLength { // The data area isn't full but there isn't room for a new entry. @@ -1321,32 +1318,36 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres if dataLength%entrySize != 0 { // The Data section size should be a multiple of the expected // timestamp entry size. - return header.IPv4OptionLengthOffset, errIPv4TimestampOptInvalidLength + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptionLengthOffset, + NeedICMP: false, + } } // If the size is OK, the pointer must be corrupted. } - return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptTSPointerOffset, + NeedICMP: true, + } } if usage.actions().timestamp == optionProcess { tsOpt.UpdateTimestamp(localAddress, clock) } - return 0, nil + return nil } -var ( - errIPv4RecordRouteOptInvalidLength = errors.New("invalid length in Record Route") - errIPv4RecordRouteOptInvalidPointer = errors.New("invalid pointer in Record Route") -) - // handleRecordRoute checks and processes a Record route option. It is much // like the timestamp type 1 option, but without timestamps. The passed in // address is stored in the option in the correct spot if possible. -func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) (uint8, error) { +func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) *header.IPv4OptParameterProblem { optlen := rrOpt.Size() if optlen < header.IPv4AddressSize+header.IPv4OptionRecordRouteHdrLength { - return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptionLengthOffset, + NeedICMP: true, + } } pointer := rrOpt.Pointer() @@ -1356,7 +1357,10 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // Since the pointer is 1 based, and the header is 3 bytes long the // pointer must point beyond the header therefore 3 or less is bad. if pointer <= header.IPv4OptionRecordRouteHdrLength { - return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptRRPointerOffset, + NeedICMP: true, + } } // RFC 791 page 21 says @@ -1373,7 +1377,7 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // of these words is a copy/paste error from the timestamp option where // there are two failure reasons given. if pointer > optlen { - return 0, nil + return nil } // The data area isn't full but there isn't room for a new entry. @@ -1398,17 +1402,23 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // } if (optlen-header.IPv4OptionRecordRouteHdrLength)%header.IPv4AddressSize != 0 { // Length is bad, not on integral number of slots. - return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptionLengthOffset, + NeedICMP: true, + } } // If not length, the fault must be with the pointer. } - return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptRRPointerOffset, + NeedICMP: true, + } } if usage.actions().recordRoute == optionVerify { - return 0, nil + return nil } rrOpt.StoreAddress(localAddress) - return 0, nil + return nil } // processIPOptions parses the IPv4 options and produces a new set of options @@ -1419,7 +1429,7 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // - The location of an error if there was one (or 0 if no error) // - If there is an error, information as to what it was was. // - The replacement option set. -func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) { +func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (header.IPv4Options, *header.IPv4OptParameterProblem) { stats := e.stats.ip opts := header.IPv4Options(orig) optIter := opts.MakeIterator() @@ -1433,21 +1443,23 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt // This will need tweaking when we start really forwarding packets // as we may need to get two addresses, for rx and tx interfaces. // We will also have to take usage into account. - prefixedAddress, err := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber) + prefixedAddress, ok := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber) localAddress := prefixedAddress.Address - if err != nil { + if !ok { h := header.IPv4(pkt.NetworkHeader().View()) dstAddr := h.DestinationAddress() if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(dstAddr) { - return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress + return nil, &header.IPv4OptParameterProblem{ + NeedICMP: false, + } } localAddress = dstAddr } for { - option, done, err := optIter.Next() - if done || err != nil { - return optIter.ErrCursor, optIter.Finalize(), err + option, done, optProblem := optIter.Next() + if done || optProblem != nil { + return optIter.Finalize(), optProblem } optType := option.Type() if optType == header.IPv4OptionNOPType { @@ -1456,12 +1468,15 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt } if optType == header.IPv4OptionListEndType { optIter.PushNOPOrEnd(optType) - return 0 /* errCursor */, optIter.Finalize(), nil /* err */ + return optIter.Finalize(), nil } // check for repeating options (multiple NOPs are OK) if seenOptions[optType] { - return optIter.ErrCursor, nil, header.ErrIPv4OptDuplicate + return nil, &header.IPv4OptParameterProblem{ + Pointer: optIter.ErrCursor, + NeedICMP: true, + } } seenOptions[optType] = true @@ -1473,9 +1488,9 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt clock := e.protocol.stack.Clock() newBuffer := optIter.RemainingBuffer()[:len(*option)] _ = copy(newBuffer, option.Contents()) - offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage) - if err != nil { - return optIter.ErrCursor + offset, nil, err + if optProblem := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage); optProblem != nil { + optProblem.Pointer += optIter.ErrCursor + return nil, optProblem } optIter.ConsumeBuffer(optLen) } @@ -1485,9 +1500,9 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt if usage.actions().recordRoute != optionRemove { newBuffer := optIter.RemainingBuffer()[:len(*option)] _ = copy(newBuffer, option.Contents()) - offset, err := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage) - if err != nil { - return optIter.ErrCursor + offset, nil, err + if optProblem := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage); optProblem != nil { + optProblem.Pointer += optIter.ErrCursor + return nil, optProblem } optIter.ConsumeBuffer(optLen) } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index a9e137c24..dac7cbfd4 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -270,6 +270,11 @@ func TestIPv4Sanity(t *testing.T) { nicID = 1 randomSequence = 123 randomIdent = 42 + // In some cases Linux sets the error pointer to the start of the option + // (offset 0) instead of the actual wrong value, which is the length byte + // (offset 1). For compatibility we must do the same. Use this constant + // to indicate where this happens. + pointerOffsetForInvalidLength = 0 ) var ( ipv4Addr = tcpip.AddressWithPrefix{ @@ -439,6 +444,21 @@ func TestIPv4Sanity(t *testing.T) { replyOptions: header.IPv4Options{}, }, { + name: "bad option - no length", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 1, 1, 1, 68, + // ^-start of timestamp.. but no length.. + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 3, + }, + { name: "bad option - length 0", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), @@ -448,7 +468,27 @@ func TestIPv4Sanity(t *testing.T) { // ^ 1, 2, 3, 4, }, - shouldFail: true, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, + }, + { + name: "bad option - length 1", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 1, 9, 0, + // ^ + 1, 2, 3, 4, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, }, { name: "bad option - length big", @@ -462,7 +502,11 @@ func TestIPv4Sanity(t *testing.T) { // space is not possible. (Second byte) 1, 2, 3, 4, }, - shouldFail: true, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, }, { // This tests for some linux compatible behaviour. @@ -484,7 +528,7 @@ func TestIPv4Sanity(t *testing.T) { expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, + paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, }, { name: "multiple type 0 with room", @@ -589,7 +633,7 @@ func TestIPv4Sanity(t *testing.T) { expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, + paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, }, { name: "valid timestamp pointer", @@ -624,7 +668,7 @@ func TestIPv4Sanity(t *testing.T) { expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, + paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, }, // End of option list with illegal option after it, which should be ignored. { @@ -636,24 +680,31 @@ func TestIPv4Sanity(t *testing.T) { 68, 12, 13, 0x11, 192, 168, 1, 12, 1, 2, 3, 4, - 0, 10, 3, 99, + 0, 10, 3, 99, // EOL followed by junk }, replyOptions: header.IPv4Options{ 68, 12, 13, 0x21, 192, 168, 1, 12, 1, 2, 3, 4, - 0, 0, 0, 0, // 3 bytes unknown option - }, // ^ End of options hides following bytes. + 0, // End of Options hides following bytes. + 0, 0, 0, // 3 bytes unknown option removed. + }, }, { - // Timestamp with a size too small. + // Timestamp with a size much too small. name: "timestamp truncated", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, - options: header.IPv4Options{68, 1, 0, 0}, - // ^ Smallest possible is 8. - shouldFail: true, + options: header.IPv4Options{ + 68, 1, 0, 0, + // ^ Smallest possible is 8. Linux points at the 68. + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, }, { name: "single record route with room", @@ -751,7 +802,7 @@ func TestIPv4Sanity(t *testing.T) { expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, + paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, }, { // Pointer must be 4 or more as it must point past the 3 byte header @@ -769,7 +820,7 @@ func TestIPv4Sanity(t *testing.T) { expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, + paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, }, { // Pointer must be 4 or more as it must point past the 3 byte header @@ -808,8 +859,7 @@ func TestIPv4Sanity(t *testing.T) { expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, - replyOptions: header.IPv4Options{}, + paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, }, { name: "duplicate record route", @@ -828,7 +878,6 @@ func TestIPv4Sanity(t *testing.T) { ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, paramProblemPointer: header.IPv4MinimumSize + 7, - replyOptions: header.IPv4Options{}, }, } @@ -884,7 +933,6 @@ func TestIPv4Sanity(t *testing.T) { if test.maxTotalLength < totalLen { totalLen = test.maxTotalLength } - ip.Encode(&header.IPv4Fields{ TotalLength: totalLen, Protocol: test.transportProtocol, @@ -2608,7 +2656,7 @@ func (*limitedMatcher) Name() string { } // Match implements Matcher.Match. -func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) { if lm.limit == 0 { return true, false } diff --git a/pkg/tcpip/network/ipv4/stats.go b/pkg/tcpip/network/ipv4/stats.go index 7620728f9..bee72c649 100644 --- a/pkg/tcpip/network/ipv4/stats.go +++ b/pkg/tcpip/network/ipv4/stats.go @@ -35,7 +35,7 @@ type Stats struct { } // IsNetworkEndpointStats implements stack.NetworkEndpointStats. -func (s *Stats) IsNetworkEndpointStats() {} +func (*Stats) IsNetworkEndpointStats() {} // IPStats implements stack.IPNetworkEndointStats func (s *Stats) IPStats() *tcpip.IPStats { diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go index 84641bcf4..b28e7dcde 100644 --- a/pkg/tcpip/network/ipv4/stats_test.go +++ b/pkg/tcpip/network/ipv4/stats_test.go @@ -34,7 +34,7 @@ func (t *testInterface) ID() tcpip.NICID { return t.nicID } -func getKnownNICIDs(proto *protocol) []tcpip.NICID { +func knownNICIDs(proto *protocol) []tcpip.NICID { var nicIDs []tcpip.NICID for k := range proto.mu.eps { @@ -51,30 +51,28 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) { proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) nic := testInterface{nicID: 1} ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) - { - proto.mu.Lock() - foundEP, hasEP := proto.mu.eps[nic.ID()] - nicIDs := getKnownNICIDs(proto) - proto.mu.Unlock() + var nicIDs []tcpip.NICID + + proto.mu.Lock() + foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] + nicIDs = knownNICIDs(proto) + proto.mu.Unlock() - if !hasEP { - t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs) - } - if foundEP != ep { - t.Fatalf("expected protocol to map endpoint %p to nic id %d, but endpoint %p was found instead", ep, nic.ID(), foundEP) - } + if !hasEndpointBeforeClose { + t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs) + } + if foundEP != ep { + t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) } ep.Close() - { - proto.mu.Lock() - _, hasEP := proto.mu.eps[nic.ID()] - nicIDs := getKnownNICIDs(proto) - proto.mu.Unlock() - if hasEP { - t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) - } + proto.mu.Lock() + _, hasEP := proto.mu.eps[nic.ID()] + nicIDs = knownNICIDs(proto) + proto.mu.Unlock() + if hasEP { + t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) } } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index ae5179d93..95efada3a 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -688,25 +688,38 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { // LinkAddressRequest implements stack.LinkAddressResolver. func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { + nicID := nic.ID() + + p.mu.Lock() + netEP, ok := p.mu.eps[nicID] + p.mu.Unlock() + if !ok { + return tcpip.ErrNotConnected + } + remoteAddr := targetAddr if len(remoteLinkAddr) == 0 { remoteAddr = header.SolicitedNodeAddr(targetAddr) remoteLinkAddr = header.EthernetAddressFromMulticastIPv6Address(remoteAddr) } - r, err := p.stack.FindRoute(nic.ID(), localAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err + if len(localAddr) == 0 { + addressEndpoint := netEP.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */) + if addressEndpoint == nil { + return tcpip.ErrNetworkUnreachable + } + + localAddr = addressEndpoint.AddressWithPrefix().Address + } else if p.stack.CheckLocalAddress(nicID, ProtocolNumber, localAddr) == 0 { + return tcpip.ErrBadLocalAddress } - defer r.Release() - r.ResolveWith(remoteLinkAddr) optsSerializer := header.NDPOptionsSerializer{ header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()), } neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborSolicitSize, + ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.IPv6FixedHeaderSize + neighborSolicitSize, }) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) @@ -714,20 +727,18 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot ns := header.NDPNeighborSolicit(packet.MessageBody()) ns.SetTargetAddress(targetAddr) ns.Options().Serialize(optsSerializer) - packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + packet.SetChecksum(header.ICMPv6Checksum(packet, localAddr, remoteAddr, buffer.VectorisedView{})) - p.mu.Lock() - netEP, ok := p.mu.eps[nic.ID()] - p.mu.Unlock() - if !ok { - return tcpip.ErrNotConnected + if err := addIPHeader(localAddr, remoteAddr, pkt, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }, header.IPv6ExtHdrSerializer{}); err != nil { + panic(fmt.Sprintf("failed to add IP header: %s", err)) } + stat := netEP.stats.icmp.packetsSent - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ - Protocol: header.ICMPv6ProtocolNumber, - TTL: header.NDPHopLimit, - }, pkt); err != nil { + if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { stat.dropped.Increment() return err } @@ -910,11 +921,11 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi // the error message packet exceed the minimum IPv6 MTU // [IPv6]. mtu := int(route.MTU()) - if mtu > header.IPv6MinimumMTU { - mtu = header.IPv6MinimumMTU + const maxIPv6Data = header.IPv6MinimumMTU - header.IPv6FixedHeaderSize + if mtu > maxIPv6Data { + mtu = maxIPv6Data } - headerLen := int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize - available := int(mtu) - headerLen + available := mtu - header.ICMPv6ErrorHeaderSize if available < header.IPv6MinimumSize { return nil } @@ -928,7 +939,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi payload.CapLength(payloadLen) newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: headerLen, + ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize, Data: payload, }) newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index defea46b0..641c60b7c 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -15,6 +15,7 @@ package ipv6 import ( + "bytes" "context" "net" "reflect" @@ -22,6 +23,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -638,7 +640,6 @@ func TestLinkResolution(t *testing.T) { pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) pkt.SetType(header.ICMPv6EchoRequest) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - payload := tcpip.SlicePayload(hdr.View()) // We can't send our payload directly over the route because that // doesn't provoke NDP discovery. @@ -648,8 +649,12 @@ func TestLinkResolution(t *testing.T) { t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err) } - if _, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil { - t.Fatalf("ep.Write(_): %s", err) + { + var r bytes.Reader + r.Reset(hdr.View()) + if _, err := ep.Write(&r, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil { + t.Fatalf("ep.Write(_): %s", err) + } } for _, args := range []routeArgs{ {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, @@ -1316,13 +1321,13 @@ func TestLinkAddressRequest(t *testing.T) { name: "Unicast with unassigned address", localAddr: lladdr1, remoteLinkAddr: linkAddr1, - expectedErr: tcpip.ErrNetworkUnreachable, + expectedErr: tcpip.ErrBadLocalAddress, }, { name: "Multicast with unassigned address", localAddr: lladdr1, remoteLinkAddr: "", - expectedErr: tcpip.ErrNetworkUnreachable, + expectedErr: tcpip.ErrBadLocalAddress, }, { name: "Unicast with no local address available", @@ -1337,58 +1342,58 @@ func TestLinkAddressRequest(t *testing.T) { } for _, test := range tests { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - p := s.NetworkProtocolInstance(ProtocolNumber) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") - } + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + p := s.NetworkProtocolInstance(ProtocolNumber) + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") + } - linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) - if err := s.CreateNIC(nicID, linkEP); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if len(test.nicAddr) != 0 { - if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) + linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) + if err := s.CreateNIC(nicID, linkEP); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if len(test.nicAddr) != 0 { + if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) + } } - } - // We pass a test network interface to LinkAddressRequest with the same NIC - // ID and link endpoint used by the NIC we created earlier so that we can - // mock a link address request and observe the packets sent to the link - // endpoint even though the stack uses the real NIC. - if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { - t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) - } + // We pass a test network interface to LinkAddressRequest with the same NIC + // ID and link endpoint used by the NIC we created earlier so that we can + // mock a link address request and observe the packets sent to the link + // endpoint even though the stack uses the real NIC. + if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) + } - if test.expectedErr != nil { - return - } + if test.expectedErr != nil { + return + } - pkt, ok := linkEP.Read() - if !ok { - t.Fatal("expected to send a link address request") - } - if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) - } - if pkt.Route.RemoteAddress != test.expectedRemoteAddr { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr) - } - if pkt.Route.LocalAddress != lladdr1 { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1) - } - checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), - checker.SrcAddr(lladdr1), - checker.DstAddr(test.expectedRemoteAddr), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(lladdr0), - checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}), - )) + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + var want stack.RouteInfo + want.NetProto = ProtocolNumber + want.RemoteLinkAddress = test.expectedRemoteLinkAddr + if diff := cmp.Diff(want, pkt.Route, cmp.AllowUnexported(want)); diff != "" { + t.Errorf("route info mismatch (-want +got):\n%s", diff) + } + checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), + checker.SrcAddr(lladdr1), + checker.DstAddr(test.expectedRemoteAddr), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(lladdr0), + checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}), + )) + }) } } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 37884505e..d658f9bcb 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -555,7 +555,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } -func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) *tcpip.Error { +func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) *tcpip.Error { extHdrsLen := extensionHeaders.Length() length := pkt.Size() + extensionHeaders.Length() if length > math.MaxUint16 { @@ -625,14 +625,14 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */); err != nil { + if err := addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */); err != nil { return err } // iptables filtering. All packets that reach here are locally // generated. - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -718,7 +718,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe stats := e.stats.ip linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { - if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */); err != nil { + if err := addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */); err != nil { return 0, err } @@ -747,8 +747,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -897,7 +897,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. if !e.nic.IsLoopback() { - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -955,7 +956,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.IPTablesInputDropped.Increment() return diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index aa892d043..5276878a0 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -371,12 +371,10 @@ func TestAddIpv6Address(t *testing.T) { t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err) } - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) - } - if addr.Address != test.addr { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr) + if addr, ok := s.GetMainNICAddress(1, header.IPv6ProtocolNumber); !ok { + t.Fatalf("got stack.GetMainNICAddress(1, %d) = (_, false), want = (_, true)", header.IPv6ProtocolNumber) + } else if addr.Address != test.addr { + t.Fatalf("got stack.GetMainNICAddress(1_, %d) = (%s, true), want = (%s, true)", header.IPv6ProtocolNumber, addr.Address, test.addr) } }) } @@ -2575,7 +2573,7 @@ func (*limitedMatcher) Name() string { } // Match implements Matcher.Match. -func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) { if lm.limit == 0 { return true, false } @@ -2583,7 +2581,7 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, return false, false } -func getKnownNICIDs(proto *protocol) []tcpip.NICID { +func knownNICIDs(proto *protocol) []tcpip.NICID { var nicIDs []tcpip.NICID for k := range proto.mu.eps { @@ -2600,29 +2598,27 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) { proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) var nic testInterface ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) - { - proto.mu.Lock() - foundEP, hasEP := proto.mu.eps[nic.ID()] - nicIDs := getKnownNICIDs(proto) - proto.mu.Unlock() - if !hasEP { - t.Fatalf("expected to find the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) - } - if foundEP != ep { - t.Fatalf("expected protocol to map endpoint %p to nic id %d, but endpoint %p was found instead", ep, nic.ID(), foundEP) - } + var nicIDs []tcpip.NICID + + proto.mu.Lock() + foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] + nicIDs = knownNICIDs(proto) + proto.mu.Unlock() + if !hasEndpointBeforeClose { + t.Fatalf("expected to find the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) + } + if foundEP != ep { + t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) } ep.Close() - { - proto.mu.Lock() - _, hasEP := proto.mu.eps[nic.ID()] - nicIDs := getKnownNICIDs(proto) - proto.mu.Unlock() - if hasEP { - t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) - } + proto.mu.Lock() + _, hasEndpointAfterClose := proto.mu.eps[nic.ID()] + nicIDs = knownNICIDs(proto) + proto.mu.Unlock() + if hasEndpointAfterClose { + t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) } } diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index 78d86e523..c376016e9 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -249,7 +249,7 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp Data: buffer.View(icmp).ToVectorisedView(), }) - if err := mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{ + if err := addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.MLDHopLimit, }, extensionHeaders); err != nil { diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 41112a0c4..ca4ff621d 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -732,7 +732,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add }) sent := ndp.ep.stats.icmp.packetsSent - if err := ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{ + if err := addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, }, nil /* extensionHeaders */); err != nil { @@ -1857,7 +1857,7 @@ func (ndp *ndpState) startSolicitingRouters() { }) sent := ndp.ep.stats.icmp.packetsSent - if err := ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ + if err := addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, }, nil /* extensionHeaders */); err != nil { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index b1a5a5510..7a22309e5 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -162,6 +162,11 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { } } +type linkResolutionResult struct { + linkAddr tcpip.LinkAddress + ok bool +} + // TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a // valid NDP NS message with the Source Link Layer Address option results in a // new entry in the link address cache for the sender of the message. @@ -231,35 +236,28 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { Data: hdr.View().ToVectorisedView(), })) - linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) - if linkAddr != test.expectedLinkAddr { - t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr) - } - - if test.expectedLinkAddr != "" { - if err != nil { - t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) - } - if c != nil { - t.Errorf("got unexpected channel") - } + ch := make(chan stack.LinkResolutionResult, 1) + err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) { + ch <- r + }) - // Invalid count should not have increased. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } - } else { - if err != tcpip.ErrWouldBlock { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) - } - if c == nil { - t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) - } + wantInvalid := uint64(0) + wantErr := (*tcpip.Error)(nil) + wantSucccess := true + if len(test.expectedLinkAddr) == 0 { + wantInvalid = 1 + wantErr = tcpip.ErrWouldBlock + wantSucccess = false + } - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Errorf("got invalid = %d, want = 1", got) - } + if err != wantErr { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, wantErr) + } + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" { + t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff) + } + if got := invalid.Value(); got != wantInvalid { + t.Errorf("got invalid = %d, want = %d", got, wantInvalid) } }) } @@ -640,18 +638,12 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatal("expected an NDP NS response") } - if p.Route.LocalAddress != nicAddr { - t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, nicAddr) - } - if p.Route.LocalLinkAddress != nicLinkAddr { - t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) - } respNSDst := header.SolicitedNodeAddr(test.nsSrc) - if p.Route.RemoteAddress != respNSDst { - t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst) - } - if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + var want stack.RouteInfo + want.NetProto = ProtocolNumber + want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst) + if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" { + t.Errorf("route info mismatch (-want +got):\n%s", diff) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), @@ -803,35 +795,28 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { Data: hdr.View().ToVectorisedView(), })) - linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) - if linkAddr != test.expectedLinkAddr { - t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr) - } - - if test.expectedLinkAddr != "" { - if err != nil { - t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) - } - if c != nil { - t.Errorf("got unexpected channel") - } + ch := make(chan stack.LinkResolutionResult, 1) + err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) { + ch <- r + }) - // Invalid count should not have increased. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } - } else { - if err != tcpip.ErrWouldBlock { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) - } - if c == nil { - t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) - } + wantInvalid := uint64(0) + wantErr := (*tcpip.Error)(nil) + wantSucccess := true + if len(test.expectedLinkAddr) == 0 { + wantInvalid = 1 + wantErr = tcpip.ErrWouldBlock + wantSucccess = false + } - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Errorf("got invalid = %d, want = 1", got) - } + if err != wantErr { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, wantErr) + } + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" { + t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff) + } + if got := invalid.Value(); got != wantInvalid { + t.Errorf("got invalid = %d, want = %d", got, wantInvalid) } }) } diff --git a/pkg/tcpip/network/ipv6/stats.go b/pkg/tcpip/network/ipv6/stats.go index a2f2f4f78..0839be3cd 100644 --- a/pkg/tcpip/network/ipv6/stats.go +++ b/pkg/tcpip/network/ipv6/stats.go @@ -32,7 +32,7 @@ type Stats struct { } // IsNetworkEndpointStats implements stack.NetworkEndpointStats. -func (s *Stats) IsNetworkEndpointStats() {} +func (*Stats) IsNetworkEndpointStats() {} // IPStats implements stack.IPNetworkEndointStats func (s *Stats) IPStats() *tcpip.IPStats { diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD index 652b92a21..bd62c4482 100644 --- a/pkg/tcpip/network/testutil/BUILD +++ b/pkg/tcpip/network/testutil/BUILD @@ -9,6 +9,7 @@ go_library( "testutil_unsafe.go", ], visibility = [ + "//pkg/tcpip/network/arp:__pkg__", "//pkg/tcpip/network/fragmentation:__pkg__", "//pkg/tcpip/network/ipv4:__pkg__", "//pkg/tcpip/network/ipv6:__pkg__", diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD index cf0a5fefe..db9b91815 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD @@ -8,7 +8,6 @@ go_binary( visibility = ["//:sandbox"], deps = [ "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/link/fdbased", "//pkg/tcpip/link/rawfile", diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 3b4f900e3..3d9954c84 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -41,7 +41,7 @@ package main import ( - "bufio" + "bytes" "fmt" "log" "math/rand" @@ -51,7 +51,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" @@ -71,24 +70,21 @@ func writer(ch chan struct{}, ep tcpip.Endpoint) { close(ch) }() - r := bufio.NewReader(os.Stdin) - for { - v := buffer.NewView(1024) - n, err := r.Read(v) - if err != nil { - return - } - - v.CapLength(n) - for len(v) > 0 { - n, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - if err != nil { - fmt.Println("Write failed:", err) - return + var b bytes.Buffer + if err := func() error { + for { + if _, err := b.ReadFrom(os.Stdin); err != nil { + return fmt.Errorf("b.ReadFrom failed: %w", err) } - v.TrimFront(int(n)) + for b.Len() != 0 { + if _, err := ep.Write(&b, tcpip.WriteOptions{Atomic: true}); err != nil { + return fmt.Errorf("ep.Write failed: %s", err) + } + } } + }(); err != nil { + fmt.Println(err) } } diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 3ac562756..ae9cf44e7 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -20,6 +20,7 @@ package main import ( + "bytes" "flag" "io" "log" @@ -58,7 +59,9 @@ func (e *tcpipError) Error() string { } func (e *endpointWriter) Write(p []byte) (int, error) { - n, err := e.ep.Write(tcpip.SlicePayload(p), tcpip.WriteOptions{}) + var r bytes.Reader + r.Reset(p) + n, err := e.ep.Write(&r, tcpip.WriteOptions{}) if err != nil { return int(n), &tcpipError{ inner: err, diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index f3ad40fdf..7eabbc599 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -15,11 +15,16 @@ package tcpip import ( + "math" "sync/atomic" "gvisor.dev/gvisor/pkg/sync" ) +// PacketOverheadFactor is used to multiply the value provided by the user on a +// SetSockOpt for setting the send/receive buffer sizes sockets. +const PacketOverheadFactor = 2 + // SocketOptionsHandler holds methods that help define endpoint specific // behavior for socket level socket options. These must be implemented by // endpoints to get notified when socket level options are set. @@ -48,6 +53,12 @@ type SocketOptionsHandler interface { // HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE. HasNIC(v int32) bool + + // GetSendBufferSize is invoked to get the SO_SNDBUFSIZE. + GetSendBufferSize() (int64, *Error) + + // IsUnixSocket is invoked to check if the socket is of unix domain. + IsUnixSocket() bool } // DefaultSocketOptionsHandler is an embeddable type that implements no-op @@ -84,6 +95,27 @@ func (*DefaultSocketOptionsHandler) HasNIC(int32) bool { return false } +// GetSendBufferSize implements SocketOptionsHandler.GetSendBufferSize. +func (*DefaultSocketOptionsHandler) GetSendBufferSize() (int64, *Error) { + return 0, nil +} + +// IsUnixSocket implements SocketOptionsHandler.IsUnixSocket. +func (*DefaultSocketOptionsHandler) IsUnixSocket() bool { + return false +} + +// StackHandler holds methods to access the stack options. These must be +// implemented by the stack. +type StackHandler interface { + // Option allows retrieving stack wide options. + Option(option interface{}) *Error + + // TransportProtocolOption allows retrieving individual protocol level + // option values. + TransportProtocolOption(proto TransportProtocolNumber, option GettableTransportProtocolOption) *Error +} + // SocketOptions contains all the variables which store values for SOL_SOCKET, // SOL_IP, SOL_IPV6 and SOL_TCP level options. // @@ -91,6 +123,9 @@ func (*DefaultSocketOptionsHandler) HasNIC(int32) bool { type SocketOptions struct { handler SocketOptionsHandler + // StackHandler is initialized at the creation time and will not change. + stackHandler StackHandler `state:"manual"` + // These fields are accessed and modified using atomic operations. // broadcastEnabled determines whether datagram sockets are allowed to @@ -170,6 +205,14 @@ type SocketOptions struct { // bindToDevice determines the device to which the socket is bound. bindToDevice int32 + // getSendBufferLimits provides the handler to get the min, default and + // max size for send buffer. It is initialized at the creation time and + // will not change. + getSendBufferLimits GetSendBufferLimits `state:"manual"` + + // sendBufferSize determines the send buffer size for this socket. + sendBufferSize int64 + // mu protects the access to the below fields. mu sync.Mutex `state:"nosave"` @@ -180,8 +223,10 @@ type SocketOptions struct { // InitHandler initializes the handler. This must be called before using the // socket options utility. -func (so *SocketOptions) InitHandler(handler SocketOptionsHandler) { +func (so *SocketOptions) InitHandler(handler SocketOptionsHandler, stack StackHandler, getSendBufferLimits GetSendBufferLimits) { so.handler = handler + so.stackHandler = stack + so.getSendBufferLimits = getSendBufferLimits } func storeAtomicBool(addr *uint32, v bool) { @@ -518,3 +563,44 @@ func (so *SocketOptions) SetBindToDevice(bindToDevice int32) *Error { atomic.StoreInt32(&so.bindToDevice, bindToDevice) return nil } + +// GetSendBufferSize gets value for SO_SNDBUF option. +func (so *SocketOptions) GetSendBufferSize() (int64, *Error) { + if so.handler.IsUnixSocket() { + return so.handler.GetSendBufferSize() + } + return atomic.LoadInt64(&so.sendBufferSize), nil +} + +// SetSendBufferSize sets value for SO_SNDBUF option. notify indicates if the +// stack handler should be invoked to set the send buffer size. +func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { + if so.handler.IsUnixSocket() { + return + } + + v := sendBufferSize + if notify { + // TODO(b/176170271): Notify waiters after size has grown. + // Make sure the send buffer size is within the min and max + // allowed. + ss := so.getSendBufferLimits(so.stackHandler) + min := int64(ss.Min) + max := int64(ss.Max) + // Validate the send buffer size with min and max values. + // Multiply it by factor of 2. + if v > max { + v = max + } + + if v < math.MaxInt32/PacketOverheadFactor { + v *= PacketOverheadFactor + if v < min { + v = min + } + } else { + v = math.MaxInt32 + } + } + atomic.StoreInt64(&so.sendBufferSize, v) +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 09c7811fa..04af933a6 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -267,11 +267,11 @@ const ( // dropped. // // TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from -// which address and nicName can be gathered. Currently, address is only -// needed for prerouting and nicName is only needed for output. +// which address can be gathered. Currently, address is only needed for +// prerouting. // // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { return true } @@ -302,7 +302,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer table = it.v4Tables[tableID] } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -385,10 +385,10 @@ func (it *IPTables) startReaper(interval time.Duration) { // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, nicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.Check(hook, pkt, gso, r, "", nicName); !ok { + if ok := it.Check(hook, pkt, gso, r, "", inNicName, outNicName); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } @@ -408,11 +408,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict { case RuleAccept: return chainAccept @@ -429,7 +429,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, inNicName, outNicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -455,11 +455,11 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. - if !rule.Filter.match(pkt, hook, nicName) { + if !rule.Filter.match(pkt, hook, inNicName, outNicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -467,7 +467,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // Go through each rule matcher. If they all match, run // the rule target. for _, matcher := range rule.Matchers { - matches, hotdrop := matcher.Match(hook, pkt, "") + matches, hotdrop := matcher.Match(hook, pkt, inNicName, outNicName) if hotdrop { return RuleDrop, 0 } diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 56a3e7861..fd9d61e39 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -210,8 +210,19 @@ type IPHeaderFilter struct { // filter will match packets that fail the source comparison. SrcInvert bool - // OutputInterface matches the name of the outgoing interface for the - // packet. + // InputInterface matches the name of the incoming interface for the packet. + InputInterface string + + // InputInterfaceMask masks the characters of the interface name when + // comparing with InputInterface. + InputInterfaceMask string + + // InputInterfaceInvert inverts the meaning of incoming interface check, + // i.e. when true the filter will match packets that fail the incoming + // interface comparison. + InputInterfaceInvert bool + + // OutputInterface matches the name of the outgoing interface for the packet. OutputInterface string // OutputInterfaceMask masks the characters of the interface name when @@ -228,7 +239,7 @@ type IPHeaderFilter struct { // // Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4 // or IPv6 header length. -func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool { +func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool { // Extract header fields. var ( // TODO(gvisor.dev/issue/170): Support other filter fields. @@ -264,26 +275,35 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) boo return false } - // Check the output interface. - // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING - // hooks after supported. - if hook == Output { - n := len(fl.OutputInterface) - if n == 0 { - return true - } - - // If the interface name ends with '+', any interface which - // begins with the name should be matched. - ifName := fl.OutputInterface - matches := nicName == ifName - if strings.HasSuffix(ifName, "+") { - matches = strings.HasPrefix(nicName, ifName[:n-1]) - } - return fl.OutputInterfaceInvert != matches + switch hook { + case Prerouting, Input: + return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) + case Output: + return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert) + case Forward, Postrouting: + // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING + // hooks after supported. + return true + default: + panic(fmt.Sprintf("unknown hook: %d", hook)) } +} - return true +func matchIfName(nicName string, ifName string, invert bool) bool { + n := len(ifName) + if n == 0 { + // If the interface name is omitted in the filter, any interface will match. + return true + } + // If the interface name ends with '+', any interface which begins with the + // name should be matched. + var matches bool + if strings.HasSuffix(ifName, "+") { + matches = strings.HasPrefix(nicName, ifName[:n-1]) + } else { + matches = nicName == ifName + } + return matches != invert } // NetworkProtocol returns the protocol (IPv4 or IPv6) on to which the header @@ -320,7 +340,7 @@ type Matcher interface { // used for suspicious packets. // // Precondition: packet.NetworkHeader is set. - Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) + Match(hook Hook, packet *PacketBuffer, inputInterfaceName, outputInterfaceName string) (matches bool, hotdrop bool) } // A Target is the interface for taking an action for a packet. diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 3c4fa341e..ba6d56a7d 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -32,6 +32,8 @@ var _ LinkAddressCache = (*linkAddrCache)(nil) // // This struct is safe for concurrent use. type linkAddrCache struct { + nic *NIC + // ageLimit is how long a cache entry is valid for. ageLimit time.Duration @@ -79,6 +81,8 @@ type linkAddrEntry struct { // linkAddrEntryEntry access is synchronized by the linkAddrCache lock. linkAddrEntryEntry + cache *linkAddrCache + // TODO(gvisor.dev/issue/5150): move these fields under mu. // mu protects the fields below. mu sync.RWMutex @@ -93,17 +97,26 @@ type linkAddrEntry struct { done chan struct{} // onResolve is called with the result of address resolution. - onResolve []func(tcpip.LinkAddress, bool) + onResolve []func(LinkResolutionResult) } func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { + res := LinkResolutionResult{LinkAddress: linkAddr, Success: len(linkAddr) != 0} for _, callback := range e.onResolve { - callback(linkAddr, len(linkAddr) != 0) + callback(res) } e.onResolve = nil if ch := e.done; ch != nil { close(ch) e.done = nil + // Dequeue the pending packets in a new goroutine to not hold up the current + // goroutine as writing packets may be a costly operation. + // + // At the time of writing, when writing packets, a neighbor's link address + // is resolved (which ends up obtaining the entry's lock) while holding the + // link resolution queue's lock. Dequeuing packets in a new goroutine avoids + // a lock ordering violation. + go e.cache.nic.linkResQueue.dequeue(ch, linkAddr, len(linkAddr) != 0) } } @@ -174,8 +187,9 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry { } *entry = linkAddrEntry{ - addr: k, - s: incomplete, + cache: c, + addr: k, + s: incomplete, } c.cache.table[k] = entry c.cache.lru.PushFront(entry) @@ -183,7 +197,7 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry { } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { c.cache.Lock() defer c.cache.Unlock() entry := c.getOrCreateEntryLocked(k) @@ -195,7 +209,7 @@ func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localA if !time.Now().After(entry.expiration) { // Not expired. if onResolve != nil { - onResolve(entry.linkAddr, true) + onResolve(LinkResolutionResult{LinkAddress: entry.linkAddr, Success: true}) } return entry.linkAddr, nil, nil } @@ -264,8 +278,9 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt return true } -func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { +func newLinkAddrCache(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { c := &linkAddrCache{ + nic: nic, ageLimit: ageLimit, resolutionTimeout: resolutionTimeout, resolutionAttempts: resolutionAttempts, diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 8c35067c6..88fbbf3fe 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -93,8 +93,14 @@ func getBlocking(c *linkAddrCache, addr tcpip.Address, linkRes LinkAddressResolv } } +func newEmptyNIC() *NIC { + n := &NIC{} + n.linkResQueue.init(n) + return n +} + func TestCacheOverflow(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) for i := len(testAddrs) - 1; i >= 0; i-- { e := testAddrs[i] c.AddLinkAddress(e.addr, e.linkAddr) @@ -129,7 +135,7 @@ func TestCacheOverflow(t *testing.T) { } func TestCacheConcurrent(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) linkRes := &testLinkAddressResolver{cache: c} var wg sync.WaitGroup @@ -165,7 +171,7 @@ func TestCacheConcurrent(t *testing.T) { } func TestCacheAgeLimit(t *testing.T) { - c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) + c := newLinkAddrCache(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3) linkRes := &testLinkAddressResolver{cache: c} e := testAddrs[0] @@ -177,7 +183,7 @@ func TestCacheAgeLimit(t *testing.T) { } func TestCacheReplace(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) e := testAddrs[0] l2 := e.linkAddr + "2" c.AddLinkAddress(e.addr, e.linkAddr) @@ -206,7 +212,7 @@ func TestCacheResolution(t *testing.T) { // // Using a large resolution timeout decreases the probability of experiencing // this race condition and does not affect how long this test takes to run. - c := newLinkAddrCache(1<<63-1, math.MaxInt64, 1) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1) linkRes := &testLinkAddressResolver{cache: c} for i, ta := range testAddrs { got, err := getBlocking(c, ta.addr, linkRes) @@ -232,7 +238,7 @@ func TestCacheResolution(t *testing.T) { } func TestCacheResolutionFailed(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5) linkRes := &testLinkAddressResolver{cache: c} var requestCount uint32 @@ -265,7 +271,7 @@ func TestCacheResolutionFailed(t *testing.T) { func TestCacheResolutionTimeout(t *testing.T) { resolverDelay := 500 * time.Millisecond expiration := resolverDelay / 10 - c := newLinkAddrCache(expiration, 1*time.Millisecond, 3) + c := newLinkAddrCache(newEmptyNIC(), expiration, 1*time.Millisecond, 3) linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} e := testAddrs[0] diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 270f5fb1a..d7bbb25ea 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -45,6 +45,8 @@ const ( linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") + defaultPrefixLen = 128 + // Extra time to use when waiting for an async event to occur. defaultAsyncPositiveEventTimeout = 10 * time.Second @@ -330,8 +332,12 @@ func TestDADDisabled(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: defaultPrefixLen, + } + if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) } // Should get the address immediately since we should not have performed @@ -344,12 +350,8 @@ func TestDADDisabled(t *testing.T) { default: t.Fatal("expected DAD event") } - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, %d) err = %s", nicID, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Fatal(err) } // We should not have sent any NDP NS messages. @@ -440,24 +442,24 @@ func TestDADResolve(t *testing.T) { NIC: nicID, }}) - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: defaultPrefixLen, + } + if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) } // Address should not be considered bound to the NIC yet (DAD ongoing). - if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } else if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Make sure the address does not resolve before the resolution time has // passed. time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout) - if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } else if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Error(err) } // Should not get a route even if we specify the local address as the // tentative address. @@ -493,10 +495,8 @@ func TestDADResolve(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } else if addr.Address != addr1 { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Error(err) } // Should get a route using the address now that it is resolved. { @@ -662,12 +662,8 @@ func TestDADFail(t *testing.T) { // Address should not be considered bound to the NIC yet // (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Receive a packet to simulate an address conflict. @@ -691,12 +687,8 @@ func TestDADFail(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Attempting to add the address again should not fail if the address's @@ -777,12 +769,8 @@ func TestDADStop(t *testing.T) { } // Address should not be considered bound to the NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } test.stopFn(t, s) @@ -800,12 +788,8 @@ func TestDADStop(t *testing.T) { } if !test.skipFinalAddrCheck { - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } } @@ -901,26 +885,25 @@ func TestSetNDPConfigurations(t *testing.T) { } // Add addresses for each NIC. - if err := s.AddAddress(nicID1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addr1, err) + addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen} + if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err) } - if err := s.AddAddress(nicID2, header.IPv6ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addr2, err) + addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen} + if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err) } expectDADEvent(nicID2, addr2) - if err := s.AddAddress(nicID3, header.IPv6ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addr3, err) + addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen} + if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err) } expectDADEvent(nicID3, addr3) // Address should not be considered bound to NIC(1) yet // (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Should get the address on NIC(2) and NIC(3) @@ -928,31 +911,19 @@ func TestSetNDPConfigurations(t *testing.T) { // it as the stack was configured to not do DAD by // default and we only updated the NDP configurations on // NIC(1). - addr, err = s.GetMainNICAddress(nicID2, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID2, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr2 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID2, header.IPv6ProtocolNumber, addr, addr2) + if err := checkGetMainNICAddress(s, nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { + t.Fatal(err) } - addr, err = s.GetMainNICAddress(nicID3, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID3, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr3 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID3, header.IPv6ProtocolNumber, addr, addr3) + if err := checkGetMainNICAddress(s, nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { + t.Fatal(err) } // Sleep until right (500ms before) before resolution to // make sure the address didn't resolve on NIC(1) yet. const delta = 500 * time.Millisecond time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) - addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Wait for DAD to resolve. @@ -970,12 +941,8 @@ func TestSetNDPConfigurations(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID1, header.IPv6ProtocolNumber, addr, addr1) + if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { + t.Fatal(err) } }) } @@ -2946,10 +2913,8 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { t.Helper() - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) } if got := addrForNewConnection(t, s); got != addr.Address { @@ -3094,10 +3059,8 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { t.Helper() - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) } if got := addrForNewConnection(t, s); got != addr.Address { @@ -3244,10 +3207,8 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { t.Fatalf("should not have %s in the list of addresses", addr2) } // Should not have any primary endpoints. - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if want := (tcpip.AddressWithPrefix{}); got != want { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) @@ -3621,10 +3582,8 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { t.Helper() - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) } if got := addrForNewConnection(t, s); got != addr.Address { diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index acee72572..204196d00 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -126,7 +126,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA // packet prompting NUD/link address resolution. // // TODO(gvisor.dev/issue/5151): Don't return the neighbor entry. -func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) { +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, *tcpip.Error) { entry := n.getOrCreateEntry(remoteAddr, linkRes) entry.mu.Lock() defer entry.mu.Unlock() @@ -142,7 +142,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA // a node continues sending packets to that neighbor using the cached // link-layer address." if onResolve != nil { - onResolve(entry.neigh.LinkAddr, true) + onResolve(LinkResolutionResult{LinkAddress: entry.neigh.LinkAddr, Success: true}) } return entry.neigh, nil, nil case Unknown, Incomplete, Failed: diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index db27cbc73..dbdb51bb4 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -1188,12 +1188,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if err != tcpip.ErrWouldBlock { @@ -1247,12 +1244,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { t.Fatalf("store.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if err != tcpip.ErrWouldBlock { @@ -1423,12 +1417,9 @@ func TestNeighborCacheReplace(t *testing.T) { t.Fatal("store.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if err != tcpip.ErrWouldBlock { @@ -1539,12 +1530,9 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { // First, sanity check that resolution is working { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if err != tcpip.ErrWouldBlock { @@ -1576,15 +1564,9 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { entry.Addr += "2" { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if ok { - t.Error("expected unsuccessful address resolution") - } - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) - } - if t.Failed() { - t.FailNow() + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if err != tcpip.ErrWouldBlock { @@ -1627,15 +1609,9 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { t.Fatal("store.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if ok { - t.Error("expected unsuccessful address resolution") - } - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) - } - if t.Failed() { - t.FailNow() + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if err != tcpip.ErrWouldBlock { @@ -1674,15 +1650,9 @@ func TestNeighborCacheRetryResolution(t *testing.T) { // Perform address resolution with a faulty link, which will fail. { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if ok { - t.Error("expected unsuccessful address resolution") - } - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) - } - if t.Failed() { - t.FailNow() + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if err != tcpip.ErrWouldBlock { @@ -1713,9 +1683,9 @@ func TestNeighborCacheRetryResolution(t *testing.T) { // Retry address resolution with a working link. linkRes.dropReplies = false { - incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if err != tcpip.ErrWouldBlock { @@ -1772,12 +1742,9 @@ func BenchmarkCacheClear(b *testing.B) { b.Fatalf("store.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - b.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - b.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if err != tcpip.ErrWouldBlock { diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 75afb3001..53ac9bb6e 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -96,7 +96,7 @@ type neighborEntry struct { done chan struct{} // onResolve is called with the result of address resolution. - onResolve []func(tcpip.LinkAddress, bool) + onResolve []func(LinkResolutionResult) isRouter bool job *tcpip.Job @@ -143,13 +143,22 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd // // Precondition: e.mu MUST be locked. func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { + res := LinkResolutionResult{LinkAddress: e.neigh.LinkAddr, Success: succeeded} for _, callback := range e.onResolve { - callback(e.neigh.LinkAddr, succeeded) + callback(res) } e.onResolve = nil if ch := e.done; ch != nil { close(ch) e.done = nil + // Dequeue the pending packets in a new goroutine to not hold up the current + // goroutine as writing packets may be a costly operation. + // + // At the time of writing, when writing packets, a neighbor's link address + // is resolved (which ends up obtaining the entry's lock) while holding the + // link resolution queue's lock. Dequeuing packets in a new goroutine avoids + // a lock ordering violation. + go e.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded) } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index f2bca93d3..1bbfe6213 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -139,9 +139,9 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC context: ctx, stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), - linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), } - nic.linkResQueue.init() + nic.linkResQueue.init(nic) + nic.linkAddrCache = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) // Check for Neighbor Unreachability Detection support. @@ -303,6 +303,10 @@ func (n *NIC) IsLoopback() bool { // WritePacket implements NetworkLinkEndpoint. func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { + _, err := n.enqueuePacketBuffer(r, gso, protocol, pkt) + return err +} +func (n *NIC) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, *tcpip.Error) { // As per relevant RFCs, we should queue packets while we wait for link // resolution to complete. // @@ -320,16 +324,7 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb // be limited to some small value. When a queue overflows, the new arrival // SHOULD replace the oldest entry. Once address resolution completes, the // node transmits any queued packets. - if ch, err := r.Resolve(nil); err != nil { - if err == tcpip.ErrWouldBlock { - r.Acquire() - n.linkResQueue.enqueue(ch, r, protocol, pkt) - return nil - } - return err - } - - return n.writePacket(r.Fields(), gso, protocol, pkt) + return n.linkResQueue.enqueue(r, gso, protocol, pkt) } // WritePacketToRemote implements NetworkInterface. @@ -344,6 +339,9 @@ func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() + pkt.EgressRoute = r + pkt.GSOOptions = gso + pkt.NetworkProtocolNumber = protocol if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil { return err } @@ -355,9 +353,17 @@ func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN // WritePackets implements NetworkLinkEndpoint. func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - // TODO(gvisor.dev/issue/4458): Queue packets whie link address resolution - // is being peformed like WritePacket. - writtenPackets, err := n.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol) + return n.enqueuePacketBuffer(r, gso, protocol, &pkts) +} + +func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, *tcpip.Error) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + pkt.EgressRoute = r + pkt.GSOOptions = gso + pkt.NetworkProtocolNumber = protocol + } + + writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol) n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { @@ -555,7 +561,7 @@ func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadLocalAddress } -func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { if n.neigh != nil { entry, ch, err := n.neigh.entry(addr, localAddr, linkRes, onResolve) return entry.LinkAddr, ch, err diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 81d8ff6e8..c4769b17e 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -28,119 +28,219 @@ const ( maxPendingPacketsPerResolution = 256 ) +// pendingPacketBuffer is a pending packet buffer. +// +// TODO(gvisor.dev/issue/5331): Drop this when we drop WritePacket and only use +// WritePackets so we can use a PacketBufferList everywhere. +type pendingPacketBuffer interface { + len() int +} + +func (*PacketBuffer) len() int { + return 1 +} + +func (p *PacketBufferList) len() int { + return p.Len() +} + type pendingPacket struct { - route *Route - proto tcpip.NetworkProtocolNumber - pkt *PacketBuffer + routeInfo RouteInfo + gso *GSO + proto tcpip.NetworkProtocolNumber + pkt pendingPacketBuffer } // packetsPendingLinkResolution is a queue of packets pending link resolution. // // Once link resolution completes successfully, the packets will be written. type packetsPendingLinkResolution struct { - sync.Mutex + nic *NIC - // The packets to send once the resolver completes. - packets map[<-chan struct{}][]pendingPacket + mu struct { + sync.Mutex - // FIFO of channels used to cancel the oldest goroutine waiting for - // link-address resolution. - cancelChans []chan struct{} -} + // The packets to send once the resolver completes. + // + // The link resolution channel is used as the key for this map. + packets map[<-chan struct{}][]pendingPacket -func (f *packetsPendingLinkResolution) init() { - f.Lock() - defer f.Unlock() - f.packets = make(map[<-chan struct{}][]pendingPacket) + // FIFO of channels used to cancel the oldest goroutine waiting for + // link-address resolution. + // + // cancelChans holds the same channels that are used as keys to packets. + cancelChans []<-chan struct{} + } } -func incrementOutgoingPacketErrors(r *Route, proto tcpip.NetworkProtocolNumber) { - r.Stats().IP.OutgoingPacketErrors.Increment() +func (f *packetsPendingLinkResolution) incrementOutgoingPacketErrors(proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) { + n := uint64(pkt.len()) + f.nic.stack.stats.IP.OutgoingPacketErrors.IncrementBy(n) - // ok may be false if the endpoint's stats do not collect IP-related data. - if ipEndpointStats, ok := r.outgoingNIC.getNetworkEndpoint(proto).Stats().(IPNetworkEndpointStats); ok { - ipEndpointStats.IPStats().OutgoingPacketErrors.Increment() + if ipEndpointStats, ok := f.nic.getNetworkEndpoint(proto).Stats().(IPNetworkEndpointStats); ok { + ipEndpointStats.IPStats().OutgoingPacketErrors.IncrementBy(n) } } -func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - f.Lock() - defer f.Unlock() +func (f *packetsPendingLinkResolution) init(nic *NIC) { + f.mu.Lock() + defer f.mu.Unlock() + f.nic = nic + f.mu.packets = make(map[<-chan struct{}][]pendingPacket) +} - packets, ok := f.packets[ch] - if len(packets) == maxPendingPacketsPerResolution { - p := packets[0] - packets[0] = pendingPacket{} - packets = packets[1:] +// dequeue any pending packets associated with ch. +// +// If success is true, packets will be written and sent to the given remote link +// address. +func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpip.LinkAddress, success bool) { + f.mu.Lock() + packets, ok := f.mu.packets[ch] + delete(f.mu.packets, ch) - incrementOutgoingPacketErrors(r, proto) + if ok { + for i, cancelChan := range f.mu.cancelChans { + if cancelChan == ch { + f.mu.cancelChans = append(f.mu.cancelChans[:i], f.mu.cancelChans[i+1:]...) + break + } + } + } + + f.mu.Unlock() + + if ok { + f.dequeuePackets(packets, linkAddr, success) + } +} - p.route.Release() +func (f *packetsPendingLinkResolution) writePacketBuffer(r RouteInfo, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, *tcpip.Error) { + switch pkt := pkt.(type) { + case *PacketBuffer: + if err := f.nic.writePacket(r, gso, proto, pkt); err != nil { + return 0, err + } + return 1, nil + case *PacketBufferList: + return f.nic.writePackets(r, gso, proto, *pkt) + default: + panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", pkt)) } +} - if l := len(packets); l >= maxPendingPacketsPerResolution { - panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution)) +// enqueue a packet to be sent once link resolution completes. +// +// If the maximum number of pending resolutions is reached, the packets +// associated with the oldest link resolution will be dequeued as if they failed +// link resolution. +func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, *tcpip.Error) { + f.mu.Lock() + // Make sure we attempt resolution while holding f's lock so that we avoid + // a race where link resolution completes before we enqueue the packets. + // + // A @ T1: Call ResolvedFields (get link resolution channel) + // B @ T2: Complete link resolution, dequeue pending packets + // C @ T1: Enqueue packet that already completed link resolution (which will + // never dequeue) + // + // To make sure B does not interleave with A and C, we make sure A and C are + // done while holding the lock. + routeInfo, ch, err := r.resolvedFields(nil) + switch err { + case nil: + // The route resolved immediately, so we don't need to wait for link + // resolution to send the packet. + f.mu.Unlock() + return f.writePacketBuffer(routeInfo, gso, proto, pkt) + case tcpip.ErrWouldBlock: + // We need to wait for link resolution to complete. + default: + f.mu.Unlock() + return 0, err } - f.packets[ch] = append(packets, pendingPacket{ - route: r, - proto: proto, - pkt: pkt, + defer f.mu.Unlock() + + packets, ok := f.mu.packets[ch] + packets = append(packets, pendingPacket{ + routeInfo: routeInfo, + gso: gso, + proto: proto, + pkt: pkt, }) - if ok { - return - } + if len(packets) > maxPendingPacketsPerResolution { + f.incrementOutgoingPacketErrors(packets[0].proto, packets[0].pkt) + packets[0] = pendingPacket{} + packets = packets[1:] - // Wait for the link-address resolution to complete. - cancel := f.newCancelChannelLocked() - go func() { - cancelled := false - select { - case <-ch: - case <-cancel: - cancelled = true + if numPackets := len(packets); numPackets != maxPendingPacketsPerResolution { + panic(fmt.Sprintf("holding more queued packets than expected; got = %d, want <= %d", numPackets, maxPendingPacketsPerResolution)) } + } - f.Lock() - packets, ok := f.packets[ch] - delete(f.packets, ch) - f.Unlock() + f.mu.packets[ch] = packets - if !ok { - panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets")) - } + if ok { + return pkt.len(), nil + } - for _, p := range packets { - if cancelled || p.route.IsResolutionRequired() { - incrementOutgoingPacketErrors(r, proto) + cancelledPackets := f.newCancelChannelLocked(ch) - if linkResolvableEP, ok := p.route.outgoingNIC.getNetworkEndpoint(p.route.NetProto).(LinkResolvableNetworkEndpoint); ok { - linkResolvableEP.HandleLinkResolutionFailure(pkt) - } - } else { - p.route.outgoingNIC.writePacket(p.route.Fields(), nil /* gso */, p.proto, p.pkt) - } - p.route.Release() - } - }() + if len(cancelledPackets) != 0 { + // Dequeue the pending packets in a new goroutine to not hold up the current + // goroutine as handing link resolution failures may be a costly operation. + go f.dequeuePackets(cancelledPackets, "" /* linkAddr */, false /* success */) + } + + return pkt.len(), nil } -// newCancelChannel creates a channel that can cancel a pending forwarding -// activity. The oldest channel is closed if the number of open channels would -// exceed maxPendingResolutions. -func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} { - if len(f.cancelChans) == maxPendingResolutions { - ch := f.cancelChans[0] - f.cancelChans[0] = nil - f.cancelChans = f.cancelChans[1:] - close(ch) +// newCancelChannelLocked appends the link resolution channel to a FIFO. If the +// maximum number of pending resolutions is reached, the oldest channel will be +// removed and its associated pending packets will be returned. +func (f *packetsPendingLinkResolution) newCancelChannelLocked(newCH <-chan struct{}) []pendingPacket { + f.mu.cancelChans = append(f.mu.cancelChans, newCH) + if len(f.mu.cancelChans) <= maxPendingResolutions { + return nil } - if l := len(f.cancelChans); l >= maxPendingResolutions { + + ch := f.mu.cancelChans[0] + f.mu.cancelChans[0] = nil + f.mu.cancelChans = f.mu.cancelChans[1:] + if l := len(f.mu.cancelChans); l > maxPendingResolutions { panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions)) } - ch := make(chan struct{}) - f.cancelChans = append(f.cancelChans, ch) - return ch + packets, ok := f.mu.packets[ch] + if !ok { + panic("must have a packet queue for an uncancelled channel") + } + delete(f.mu.packets, ch) + + return packets +} + +func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, linkAddr tcpip.LinkAddress, success bool) { + for _, p := range packets { + if success { + p.routeInfo.RemoteLinkAddress = linkAddr + _, _ = f.writePacketBuffer(p.routeInfo, p.gso, p.proto, p.pkt) + } else { + f.incrementOutgoingPacketErrors(p.proto, p.pkt) + + if linkResolvableEP, ok := f.nic.getNetworkEndpoint(p.proto).(LinkResolvableNetworkEndpoint); ok { + switch pkt := p.pkt.(type) { + case *PacketBuffer: + linkResolvableEP.HandleLinkResolutionFailure(pkt) + case *PacketBufferList: + for pb := pkt.Front(); pb != nil; pb = pb.Next() { + linkResolvableEP.HandleLinkResolutionFailure(pb) + } + default: + panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt)) + } + } + } + } } diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 1ff7b3a37..d9a8554e2 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -86,12 +86,21 @@ type RouteInfo struct { RemoteLinkAddress tcpip.LinkAddress } -// Fields returns a RouteInfo with all of r's exported fields. This allows -// callers to store the route's fields without retaining a reference to it. +// Fields returns a RouteInfo with all of the known values for the route's +// fields. +// +// If any fields are unknown (e.g. remote link address when it is waiting for +// link address resolution), they will be unset. func (r *Route) Fields() RouteInfo { + r.mu.RLock() + defer r.mu.RUnlock() + return r.fieldsLocked() +} + +func (r *Route) fieldsLocked() RouteInfo { return RouteInfo{ routeInfo: r.routeInfo, - RemoteLinkAddress: r.RemoteLinkAddress(), + RemoteLinkAddress: r.mu.remoteLinkAddress, } } @@ -306,29 +315,45 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) { r.mu.remoteLinkAddress = addr } -// Resolve attempts to resolve the link address if necessary. +// ResolvedFieldsResult is the result of a route resolution attempt. +type ResolvedFieldsResult struct { + RouteInfo RouteInfo + Success bool +} + +// ResolvedFields attempts to resolve the remote link address if it is not +// known. // -// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g. -// waiting for ARP reply). If address resolution is required, a notification -// channel is also returned for the caller to block on. The channel is closed -// once address resolution is complete (successful or not). If a callback is -// provided, it will be called when address resolution is complete, regardless +// If a callback is provided, it will be called before ResolvedFields returns +// when address resolution is not required. If address resolution is required, +// the callback will be called once address resolution is complete, regardless // of success or failure. -func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) { - r.mu.Lock() +// +// Note, the route will not cache the remote link address when address +// resolution completes. +func (r *Route) ResolvedFields(afterResolve func(ResolvedFieldsResult)) *tcpip.Error { + _, _, err := r.resolvedFields(afterResolve) + return err +} - if !r.isResolutionRequiredRLocked() { - // Nothing to do if there is no cache (which does the resolution on cache miss) or - // link address is already known. - r.mu.Unlock() - return nil, nil +// resolvedFields is like ResolvedFields but also returns a notification channel +// when address resolution is required. This channel will become readable once +// address resolution is complete. +// +// The route's fields will also be returned, regardless of whether address +// resolution is required or not. +func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteInfo, <-chan struct{}, *tcpip.Error) { + r.mu.RLock() + fields := r.fieldsLocked() + resolutionRequired := r.isResolutionRequiredRLocked() + r.mu.RUnlock() + if !resolutionRequired { + if afterResolve != nil { + afterResolve(ResolvedFieldsResult{RouteInfo: fields, Success: true}) + } + return fields, nil, nil } - // Increment the route's reference count because finishResolution retains a - // reference to the route and releases it when called. - r.acquireLocked() - r.mu.Unlock() - nextAddr := r.NextHop if nextAddr == "" { nextAddr = r.RemoteAddress @@ -341,18 +366,20 @@ func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) { linkAddressResolutionRequestLocalAddr = r.LocalAddress } - finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) { - if ok { - r.ResolveWith(linkAddress) - } + afterResolveFields := fields + linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, func(r LinkResolutionResult) { if afterResolve != nil { - afterResolve() + if r.Success { + afterResolveFields.RemoteLinkAddress = r.LinkAddress + } + + afterResolve(ResolvedFieldsResult{RouteInfo: afterResolveFields, Success: r.Success}) } - r.Release() + }) + if err == nil { + fields.RemoteLinkAddress = linkAddr } - - _, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution) - return ch, err + return fields, ch, err } // local returns true if the route is a local route. @@ -371,11 +398,7 @@ func (r *Route) IsResolutionRequired() bool { } func (r *Route) isResolutionRequiredRLocked() bool { - if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() { - return false - } - - return r.linkRes != nil + return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local() } func (r *Route) isValidForOutgoing() bool { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 4685fa4cf..e9c5db4c3 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -444,7 +444,7 @@ type Stack struct { // sendBufferSize holds the min/default/max send buffer sizes for // endpoints other than TCP. - sendBufferSize SendBufferSizeOption + sendBufferSize tcpip.SendBufferSizeOption // receiveBufferSize holds the min/default/max receive buffer sizes for // endpoints other than TCP. @@ -646,7 +646,7 @@ func New(opts Options) *Stack { uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, randomGenerator: mathrand.New(randSrc), - sendBufferSize: SendBufferSizeOption{ + sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, Max: DefaultMaxBufferSize, @@ -1196,19 +1196,19 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { // GetMainNICAddress returns the first non-deprecated primary address and prefix // for the given NIC and protocol. If no non-deprecated primary address exists, -// a deprecated primary address and prefix will be returned. Returns an error if +// a deprecated primary address and prefix will be returned. Returns false if // the NIC doesn't exist and an empty value if the NIC doesn't have a primary // address for the given protocol. -func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { +func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, bool) { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { - return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID + return tcpip.AddressWithPrefix{}, false } - return nic.primaryAddress(protocol), nil + return nic.primaryAddress(protocol), true } func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint { @@ -1527,9 +1527,13 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAd return nil } -// GetLinkAddress finds the link address corresponding to a neighbor's address. -// -// Returns a link address for the remote address, if readily available. +// LinkResolutionResult is the result of a link address resolution attempt. +type LinkResolutionResult struct { + LinkAddress tcpip.LinkAddress + Success bool +} + +// GetLinkAddress finds the link address corresponding to a network address. // // Returns ErrNotSupported if the stack is not configured with a link address // resolver for the specified network protocol. @@ -1538,30 +1542,33 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAd // with a notification channel for the caller to block on. Triggers address // resolution asynchronously. // -// If onResolve is provided, it will be called either immediately, if -// resolution is not required, or when address resolution is complete, with -// the resolved link address and whether resolution succeeded. After any -// callbacks have been called, the returned notification channel is closed. +// onResolve will be called either immediately, if resolution is not required, +// or when address resolution is complete, with the resolved link address and +// whether resolution succeeded. // // If specified, the local address must be an address local to the interface // the neighbor cache belongs to. The local address is the source address of // a packet prompting NUD/link address resolution. -// -// TODO(gvisor.dev/issue/5151): Don't return the link address. -func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) *tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return "", nil, tcpip.ErrUnknownNICID + return tcpip.ErrUnknownNICID } linkRes, ok := s.linkAddrResolvers[protocol] if !ok { - return "", nil, tcpip.ErrNotSupported + return tcpip.ErrNotSupported + } + + if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok { + onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true}) + return nil } - return nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) + _, _, err := nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) + return err } // Neighbors returns all IP to MAC address associations. @@ -1622,25 +1629,25 @@ func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error { // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but // nic-specific IDs have precedence over global ones. -func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (s *Stack) RegisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // CheckRegisterTransportEndpoint checks if an endpoint can be registered with // the stack transport dispatcher. -func (s *Stack) CheckRegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (s *Stack) CheckRegisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice) } // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. -func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { +func (s *Stack) UnregisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // StartTransportEndpointCleanup removes the endpoint with the given id from // the stack transport dispatcher. It also transitions it to the cleanup stage. -func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { +func (s *Stack) StartTransportEndpointCleanup(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { s.cleanupEndpointsMu.Lock() s.cleanupEndpoints[ep] = struct{}{} s.cleanupEndpointsMu.Unlock() @@ -1665,13 +1672,13 @@ func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, tran // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. -func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { +func (s *Stack) RegisterRawTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { return s.demux.registerRawEndpoint(netProto, transProto, ep) } // UnregisterRawTransportEndpoint removes the endpoint for the transport // protocol from the stack transport dispatcher. -func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { +func (s *Stack) UnregisterRawTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { s.demux.unregisterRawEndpoint(netProto, transProto, ep) } diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go index 0b093e6c5..92e70f94e 100644 --- a/pkg/tcpip/stack/stack_options.go +++ b/pkg/tcpip/stack/stack_options.go @@ -14,7 +14,9 @@ package stack -import "gvisor.dev/gvisor/pkg/tcpip" +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) const ( // MinBufferSize is the smallest size of a receive or send buffer. @@ -29,14 +31,6 @@ const ( DefaultMaxBufferSize = 4 << 20 // 4 MiB ) -// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to -// get/set the default, min and max send buffer sizes. -type SendBufferSizeOption struct { - Min int - Default int - Max int -} - // ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to // get/set the default, min and max receive buffer sizes. type ReceiveBufferSizeOption struct { @@ -48,7 +42,7 @@ type ReceiveBufferSizeOption struct { // SetOption allows setting stack wide options. func (s *Stack) SetOption(option interface{}) *tcpip.Error { switch v := option.(type) { - case SendBufferSizeOption: + case tcpip.SendBufferSizeOption: // Make sure we don't allow lowering the buffer below minimum // required for stack to work. if v.Min < MinBufferSize { @@ -88,7 +82,7 @@ func (s *Stack) SetOption(option interface{}) *tcpip.Error { // Option allows retrieving stack wide options. func (s *Stack) Option(option interface{}) *tcpip.Error { switch v := option.(type) { - case *SendBufferSizeOption: + case *tcpip.SendBufferSizeOption: s.mu.RLock() *v = s.sendBufferSize s.mu.RUnlock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index b9ef455e5..0f02f1d53 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -60,6 +60,15 @@ const ( protocolNumberOffset = 2 ) +func checkGetMainNICAddress(s *stack.Stack, nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, want tcpip.AddressWithPrefix) error { + if addr, ok := s.GetMainNICAddress(nicID, proto); !ok { + return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, false), want = (_, true)", nicID, proto) + } else if addr != want { + return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, true), want = (%s, true)", nicID, proto, addr, want) + } + return nil +} + // fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and // received packets; the counts of all endpoints are aggregated in the protocol // descriptor. @@ -1873,20 +1882,20 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { // Check that GetMainNICAddress returns an address if at least // one primary address was added. In that case make sure the // address/prefixLen matches what we added. - gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) + gotAddr, ok := s.GetMainNICAddress(1, fakeNetNumber) + if !ok { + t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) } if len(primaryAddrAdded) == 0 { // No primary addresses present. if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { - t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, wantAddr) } } else { // At least one primary address was added, verify the returned // address is in the list of primary addresses we added. if _, ok := primaryAddrAdded[gotAddr]; !ok { - t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, primaryAddrAdded) } } }) @@ -1927,12 +1936,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { } // Check that we get the right initial address and prefix length. - gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr { - t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) + if err := checkGetMainNICAddress(s, 1, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { + t.Fatal(err) } if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil { @@ -1940,12 +1945,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { } // Check that we get no address after removal. - gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { - t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) + if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } }) } @@ -2486,12 +2487,12 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { } } - gotMainAddr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + // Check that we get no address after removal. + if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } - if gotMainAddr != expectedMainAddr { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", gotMainAddr, expectedMainAddr) + if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, expectedMainAddr); err != nil { + t.Fatal(err) } }) } @@ -2537,12 +2538,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) } - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want) + if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } }) } @@ -2573,12 +2570,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { // Address should not be considered bound to the // NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } linkLocalAddr := header.LinkLocalAddr(linkAddr1) @@ -2596,12 +2589,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); err != nil { + t.Fatal(err) } } @@ -2633,17 +2622,17 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil { t.Fatal("AddAddressWithOptions failed:", err) } - addr, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("s.GetMainNICAddress failed:", err) + addr, ok := s.GetMainNICAddress(1, fakeNetNumber) + if !ok { + t.Fatalf("GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) } if pi == stack.NeverPrimaryEndpoint { if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want) } } else if addr.Address != "\x01" { - t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address) } { @@ -2722,18 +2711,17 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { if err := s.RemoveAddress(1, "\x03"); err != nil { t.Fatalf("RemoveAddress failed: %v", err) } - addr, err = s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatalf("s.GetMainNICAddress failed: %v", err) + addr, ok = s.GetMainNICAddress(1, fakeNetNumber) + if !ok { + t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) } if ps == stack.NeverPrimaryEndpoint { if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) - + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want) } } else { if addr.Address != "\x01" { - t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address) } } }) @@ -3259,12 +3247,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { } // Address should be tentative so it should not be a main address. - got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); got != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Enabling the NIC should start DAD for the address. @@ -3276,12 +3260,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { } // Address should not be considered bound to the NIC yet (DAD ongoing). - got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); got != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Wait for DAD to resolve. @@ -3296,12 +3276,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) } - got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if got != addr.AddressWithPrefix { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil { + t.Fatal(err) } // Enabling the NIC again should be a no-op. @@ -3311,12 +3287,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) } - got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if got != addr.AddressWithPrefix { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil { + t.Fatal(err) } } @@ -3364,21 +3336,21 @@ func TestStackSendBufferSizeOption(t *testing.T) { const sMin = stack.MinBufferSize testCases := []struct { name string - ss stack.SendBufferSizeOption + ss tcpip.SendBufferSizeOption err *tcpip.Error }{ // Invalid configurations. - {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, - {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"min_below_zero", tcpip.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"min_zero", tcpip.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"default_below_min", tcpip.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"default_above_max", tcpip.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"max_below_min", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, // Valid Configurations - {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, - {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, - {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, - {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + {"in_ascending_order", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -3387,7 +3359,7 @@ func TestStackSendBufferSizeOption(t *testing.T) { if err := s.SetOption(tc.ss); err != tc.err { t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err) } - var ss stack.SendBufferSizeOption + var ss tcpip.SendBufferSizeOption if tc.err == nil { if err := s.Option(&ss); err != nil { t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err) @@ -3790,20 +3762,16 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { } // Check that we get the right initial address and prefix length. - if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil { - t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) - } else if gotAddr != protocolAddress.AddressWithPrefix { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) + if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { + t.Fatal(err) } // Should still get the address when the NIC is diabled. if err := s.DisableNIC(nicID); err != nil { t.Fatalf("DisableNIC(%d): %s", nicID, err) } - if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil { - t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) - } else if gotAddr != protocolAddress.AddressWithPrefix { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) + if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { + t.Fatal(err) } } @@ -4384,10 +4352,58 @@ func TestGetLinkAddressErrors(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if addr, _, err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrUnknownNICID { - t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrUnknownNICID) + if err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrUnknownNICID { + t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, tcpip.ErrUnknownNICID) + } + if err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrNotSupported { + t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, tcpip.ErrNotSupported) + } +} + +func TestStaticGetLinkAddress(t *testing.T) { + const ( + nicID = 1 + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + }) + if err := s.CreateNIC(nicID, channel.New(0, 0, "")); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + addr tcpip.Address + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "IPv4", + proto: ipv4.ProtocolNumber, + addr: header.IPv4Broadcast, + expectedLinkAddr: header.EthernetBroadcastAddress, + }, + { + name: "IPv6", + proto: ipv6.ProtocolNumber, + addr: header.IPv6AllNodesMulticastAddress, + expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), + }, } - if addr, _, err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrNotSupported { - t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrNotSupported) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ch := make(chan stack.LinkResolutionResult, 1) + if err := s.GetLinkAddress(nicID, test.addr, "", test.proto, func(r stack.LinkResolutionResult) { + ch <- r + }); err != nil { + t.Fatalf("s.GetLinkAddress(%d, %s, '', %d, _): %s", nicID, test.addr, test.proto, err) + } + + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: true}, <-ch); diff != "" { + t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) + } + }) } } diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 57e1f8354..de4b5fbdc 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -194,7 +194,7 @@ func TestTransportDemuxerRegister(t *testing.T) { if !ok { t.Fatalf("%T does not implement stack.TransportEndpoint", ep) } - if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want { + if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want { t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want) } }) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 9d39533a1..c49427c4c 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -15,6 +15,7 @@ package stack_test import ( + "bytes" "io" "testing" @@ -67,9 +68,9 @@ func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { return &f.ops } -func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { - ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} - ep.ops.InitHandler(ep) +func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint { + ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()} + ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits) return ep } @@ -95,10 +96,11 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return 0, tcpip.ErrNoRoute } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, tcpip.ErrBadBuffer } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen, Data: buffer.View(v).ToVectorisedView(), @@ -147,7 +149,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Try to register so that we can start receiving packets. f.ID.RemoteAddress = addr.Addr - err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) + err = f.proto.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { r.Release() return err @@ -188,7 +190,6 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { if err := f.proto.stack.RegisterTransportEndpoint( - a.NIC, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, stack.TransportEndpointID{LocalAddress: a.Addr}, @@ -232,7 +233,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * peerAddr: route.RemoteAddress, route: route, } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits) f.acceptQueue = append(f.acceptQueue, ep) } @@ -280,7 +281,7 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { } func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(f, netProto, f.stack.UniqueID()), nil + return newFakeTransportEndpoint(f, netProto, f.stack), nil } func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { @@ -520,8 +521,10 @@ func TestTransportSend(t *testing.T) { } // Create buffer that will hold the payload. - view := buffer.NewView(30) - if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + b := make([]byte, 30) + var r bytes.Reader + r.Reset(b) + if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("write failed: %v", err) } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 4f59e4ff7..812ee36ed 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -29,6 +29,7 @@ package tcpip import ( + "bytes" "errors" "fmt" "io" @@ -194,7 +195,7 @@ type ErrSaveRejection struct { } // Error returns a sensible description of the save rejection error. -func (e ErrSaveRejection) Error() string { +func (e *ErrSaveRejection) Error() string { return "save rejected due to unsupported networking state: " + e.Err.Error() } @@ -471,30 +472,15 @@ type FullAddress struct { // This interface allows the endpoint to request the amount of data it needs // based on internal buffers without exposing them. type Payloader interface { - // FullPayload returns all available bytes. - FullPayload() ([]byte, *Error) + io.Reader - // Payload returns a slice containing at most size bytes. - Payload(size int) ([]byte, *Error) + // Len returns the number of bytes of the unread portion of the + // Reader. + Len() int } -// SlicePayload implements Payloader for slices. -// -// This is typically used for tests. -type SlicePayload []byte - -// FullPayload implements Payloader.FullPayload. -func (s SlicePayload) FullPayload() ([]byte, *Error) { - return s, nil -} - -// Payload implements Payloader.Payload. -func (s SlicePayload) Payload(size int) ([]byte, *Error) { - if size > len(s) { - size = len(s) - } - return s[:size], nil -} +var _ Payloader = (*bytes.Buffer)(nil) +var _ Payloader = (*bytes.Reader)(nil) var _ io.Writer = (*SliceWriter)(nil) @@ -840,10 +826,6 @@ const ( // number of unread bytes in the input buffer should be returned. ReceiveQueueSizeOption - // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to - // specify the send buffer size option. - SendBufferSizeOption - // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to // specify the receive buffer size option. ReceiveBufferSizeOption @@ -1248,6 +1230,31 @@ type IPPacketInfo struct { DestinationAddr Address } +// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to +// get/set the default, min and max send buffer sizes. +type SendBufferSizeOption struct { + // Min is the minimum size for send buffer. + Min int + + // Default is the default size for send buffer. + Default int + + // Max is the maximum size for send buffer. + Max int +} + +// GetSendBufferLimits is used to get the send buffer size limits. +type GetSendBufferLimits func(StackHandler) SendBufferSizeOption + +// GetStackSendBufferLimits is used to get default, min and max send buffer size. +func GetStackSendBufferLimits(so StackHandler) SendBufferSizeOption { + var ss SendBufferSizeOption + if err := so.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) + } + return ss +} + // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. @@ -1681,6 +1688,8 @@ type IPStats struct { // ARPStats collects ARP-specific stats. type ARPStats struct { + // LINT.IfChange(ARPStats) + // PacketsReceived is the number of ARP packets received from the link layer. PacketsReceived *StatCounter @@ -1708,10 +1717,6 @@ type ARPStats struct { // ARP request with a bad local address. OutgoingRequestBadLocalAddressErrors *StatCounter - // OutgoingRequestNetworkUnreachableErrors is the number of failures to send - // an ARP request with a network unreachable error. - OutgoingRequestNetworkUnreachableErrors *StatCounter - // OutgoingRequestsDropped is the number of ARP requests which failed to write // to a link-layer endpoint. OutgoingRequestsDropped *StatCounter @@ -1730,6 +1735,8 @@ type ARPStats struct { // OutgoingRepliesSent is the number of ARP replies successfully written to a // link-layer endpoint. OutgoingRepliesSent *StatCounter + + // LINT.ThenChange(network/arp/stats.go:multiCounterARPStats) } // TCPStats collects TCP-specific stats. diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 1742a178d..218b218e7 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -7,6 +7,7 @@ go_test( size = "small", srcs = [ "forward_test.go", + "iptables_test.go", "link_resolution_test.go", "loopback_test.go", "multicast_broadcast_test.go", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index ac9670f9a..aedf1845e 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -436,9 +436,10 @@ func TestForwarding(t *testing.T) { write := func(ep tcpip.Endpoint, data []byte) { t.Helper() - dataPayload := tcpip.SlicePayload(data) + var r bytes.Reader + r.Reset(data) var wOpts tcpip.WriteOptions - n, err := ep.Write(dataPayload, wOpts) + n, err := ep.Write(&r, wOpts) if err != nil { t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) } @@ -486,7 +487,7 @@ func TestForwarding(t *testing.T) { read(serverCH, serverEP, data, clientAddr) - data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12}) + data = []byte{5, 6, 7, 8, 9, 10, 11, 12} write(serverEP, data) read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr) }) diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go new file mode 100644 index 000000000..21a8dd291 --- /dev/null +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -0,0 +1,336 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type inputIfNameMatcher struct { + name string +} + +var _ stack.Matcher = (*inputIfNameMatcher)(nil) + +func (*inputIfNameMatcher) Name() string { + return "inputIfNameMatcher" +} + +func (im *inputIfNameMatcher) Match(hook stack.Hook, _ *stack.PacketBuffer, inNicName, _ string) (bool, bool) { + return (hook == stack.Input && im.name != "" && im.name == inNicName), false +} + +const ( + nicID = 1 + nicName = "nic1" + anotherNicName = "nic2" + linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + srcAddrV4 = "\x0a\x00\x00\x01" + dstAddrV4 = "\x0a\x00\x00\x02" + srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + payloadSize = 20 +) + +func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) { + t.Helper() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + }) + e := channel.New(0, header.IPv6MinimumMTU, linkAddr) + nicOpts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) + } + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err) + } + return s, e +} + +func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) { + t.Helper() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) + e := channel.New(0, header.IPv4MinimumMTU, linkAddr) + nicOpts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) + } + if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err) + } + return s, e +} + +func genPacketV6() *stack.PacketBuffer { + pktSize := header.IPv6MinimumSize + payloadSize + hdr := buffer.NewPrependable(pktSize) + ip := header.IPv6(hdr.Prepend(pktSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: payloadSize, + TransportProtocol: 99, + HopLimit: 255, + SrcAddr: srcAddrV6, + DstAddr: dstAddrV6, + }) + vv := hdr.View().ToVectorisedView() + return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) +} + +func genPacketV4() *stack.PacketBuffer { + pktSize := header.IPv4MinimumSize + payloadSize + hdr := buffer.NewPrependable(pktSize) + ip := header.IPv4(hdr.Prepend(pktSize)) + ip.Encode(&header.IPv4Fields{ + TOS: 0, + TotalLength: uint16(pktSize), + ID: 1, + Flags: 0, + FragmentOffset: 16, + TTL: 48, + Protocol: 99, + SrcAddr: srcAddrV4, + DstAddr: dstAddrV4, + }) + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + vv := hdr.View().ToVectorisedView() + return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) +} + +func TestIPTablesStatsForInput(t *testing.T) { + tests := []struct { + name string + setupStack func(*testing.T) (*stack.Stack, *channel.Endpoint) + setupFilter func(*testing.T, *stack.Stack) + genPacket func() *stack.PacketBuffer + proto tcpip.NetworkProtocolNumber + expectReceived int + expectInputDropped int + }{ + { + name: "IPv6 Accept", + setupStack: genStackV6, + setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv4 Accept", + setupStack: genStackV4, + setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv6 Drop (input interface matches)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv4 Drop (input interface matches)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv6 Accept (input interface does not match)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv4 Accept (input interface does not match)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv6 Drop (input interface does not match but invert is true)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ + InputInterface: anotherNicName, + InputInterfaceInvert: true, + } + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv4 Drop (input interface does not match but invert is true)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ + InputInterface: anotherNicName, + InputInterfaceInvert: true, + } + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv6 Accept (input interface does not match using a matcher)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv4 Accept (input interface does not match using a matcher)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, e := test.setupStack(t) + test.setupFilter(t, s) + e.InjectInbound(test.proto, test.genPacket()) + + if got := int(s.Stats().IP.PacketsReceived.Value()); got != test.expectReceived { + t.Errorf("got PacketReceived = %d, want = %d", got, test.expectReceived) + } + if got := int(s.Stats().IP.IPTablesInputDropped.Value()); got != test.expectInputDropped { + t.Errorf("got IPTablesInputDropped = %d, want = %d", got, test.expectInputDropped) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index af32d3009..f85164c5b 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -23,6 +23,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" @@ -32,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -207,8 +209,10 @@ func TestPing(t *testing.T) { defer ep.Close() icmpBuf := test.icmpBuf(t) + var r bytes.Reader + r.Reset(icmpBuf) wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}} - if n, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil { + if n, err := ep.Write(&r, wOpts); err != nil { t.Fatalf("ep.Write(_, _): %s", err) } else if want := int64(len(icmpBuf)); n != want { t.Fatalf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want) @@ -358,9 +362,11 @@ func TestTCPLinkResolutionFailure(t *testing.T) { // Wait for an error due to link resolution failing, or the endpoint to be // writable. <-ch + var r bytes.Reader + r.Reset([]byte{0}) var wOpts tcpip.WriteOptions - if n, err := clientEP.Write(tcpip.SlicePayload(nil), wOpts); err != test.expectedWriteErr { - t.Errorf("got clientEP.Write(nil, %#v) = (%d, %s), want = (_, %s)", wOpts, n, err, test.expectedWriteErr) + if n, err := clientEP.Write(&r, wOpts); err != test.expectedWriteErr { + t.Errorf("got clientEP.Write(_, %#v) = (%d, %s), want = (_, %s)", wOpts, n, err, test.expectedWriteErr) } if test.expectedWriteErr == nil { @@ -404,20 +410,34 @@ func TestGetLinkAddress(t *testing.T) { ) tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectedLinkAddr bool + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectedOk bool }{ { - name: "IPv4", + name: "IPv4 resolvable", netProto: ipv4.ProtocolNumber, remoteAddr: ipv4Addr2.AddressWithPrefix.Address, + expectedOk: true, }, { - name: "IPv6", + name: "IPv6 resolvable", netProto: ipv6.ProtocolNumber, remoteAddr: ipv6Addr2.AddressWithPrefix.Address, + expectedOk: true, + }, + { + name: "IPv4 not resolvable", + netProto: ipv4.ProtocolNumber, + remoteAddr: ipv4Addr3.AddressWithPrefix.Address, + expectedOk: false, + }, + { + name: "IPv6 not resolvable", + netProto: ipv6.ProtocolNumber, + remoteAddr: ipv6Addr3.AddressWithPrefix.Address, + expectedOk: false, }, } @@ -432,27 +452,279 @@ func TestGetLinkAddress(t *testing.T) { host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) - for i := 0; i < 2; i++ { - addr, ch, err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(tcpip.LinkAddress, bool) {}) - var want *tcpip.Error - if i == 0 { - want = tcpip.ErrWouldBlock + ch := make(chan stack.LinkResolutionResult, 1) + if err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) { + ch <- r + }); err != tcpip.ErrWouldBlock { + t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, tcpip.ErrWouldBlock) + } + wantRes := stack.LinkResolutionResult{Success: test.expectedOk} + if test.expectedOk { + wantRes.LinkAddress = linkAddr2 + } + if diff := cmp.Diff(wantRes, <-ch); diff != "" { + t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) + } + }) + } + }) + } +} + +func TestRouteResolvedFields(t *testing.T) { + const ( + host1NICID = 1 + host2NICID = 4 + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + localAddr tcpip.Address + remoteAddr tcpip.Address + immediatelyResolvable bool + expectedSuccess bool + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "IPv4 immediately resolvable", + netProto: ipv4.ProtocolNumber, + localAddr: ipv4Addr1.AddressWithPrefix.Address, + remoteAddr: header.IPv4AllSystems, + immediatelyResolvable: true, + expectedSuccess: true, + expectedLinkAddr: header.EthernetAddressFromMulticastIPv4Address(header.IPv4AllSystems), + }, + { + name: "IPv6 immediately resolvable", + netProto: ipv6.ProtocolNumber, + localAddr: ipv6Addr1.AddressWithPrefix.Address, + remoteAddr: header.IPv6AllNodesMulticastAddress, + immediatelyResolvable: true, + expectedSuccess: true, + expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), + }, + { + name: "IPv4 resolvable", + netProto: ipv4.ProtocolNumber, + localAddr: ipv4Addr1.AddressWithPrefix.Address, + remoteAddr: ipv4Addr2.AddressWithPrefix.Address, + immediatelyResolvable: false, + expectedSuccess: true, + expectedLinkAddr: linkAddr2, + }, + { + name: "IPv6 resolvable", + netProto: ipv6.ProtocolNumber, + localAddr: ipv6Addr1.AddressWithPrefix.Address, + remoteAddr: ipv6Addr2.AddressWithPrefix.Address, + immediatelyResolvable: false, + expectedSuccess: true, + expectedLinkAddr: linkAddr2, + }, + { + name: "IPv4 not resolvable", + netProto: ipv4.ProtocolNumber, + localAddr: ipv4Addr1.AddressWithPrefix.Address, + remoteAddr: ipv4Addr3.AddressWithPrefix.Address, + immediatelyResolvable: false, + expectedSuccess: false, + }, + { + name: "IPv6 not resolvable", + netProto: ipv6.ProtocolNumber, + localAddr: ipv6Addr1.AddressWithPrefix.Address, + remoteAddr: ipv6Addr3.AddressWithPrefix.Address, + immediatelyResolvable: false, + expectedSuccess: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, useNeighborCache := range []bool{true, false} { + t.Run(fmt.Sprintf("UseNeighborCache=%t", useNeighborCache), func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + UseNeighborCache: useNeighborCache, + } + + host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) + r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */) + if err != nil { + t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err) + } + defer r.Release() + + var wantRouteInfo stack.RouteInfo + wantRouteInfo.LocalLinkAddress = linkAddr1 + wantRouteInfo.LocalAddress = test.localAddr + wantRouteInfo.RemoteAddress = test.remoteAddr + wantRouteInfo.NetProto = test.netProto + wantRouteInfo.Loop = stack.PacketOut + wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr + + ch := make(chan stack.ResolvedFieldsResult, 1) + + if !test.immediatelyResolvable { + wantUnresolvedRouteInfo := wantRouteInfo + wantUnresolvedRouteInfo.RemoteLinkAddress = "" + + if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { + ch <- r + }); err != tcpip.ErrWouldBlock { + t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, tcpip.ErrWouldBlock) } - if err != want { - t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = (%s, _, %s), want = (_, _, %s)", host1NICID, test.remoteAddr, test.netProto, addr, err, want) + if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: test.expectedSuccess}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { + t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) } - if i == 0 { - <-ch - continue + if !test.expectedSuccess { + return } - if addr != linkAddr2 { - t.Fatalf("got addr = %s, want = %s", addr, linkAddr2) + // At this point the neighbor table should be populated so the route + // should be immediately resolvable. + } + + if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { + ch <- r + }); err != nil { + t.Errorf("r.ResolvedFields(_): %s", err) + } + select { + case routeResolveRes := <-ch: + if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: true}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { + t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("expected route to be immediately resolvable") } }) } }) } } + +func TestWritePacketsLinkResolution(t *testing.T) { + const ( + host1NICID = 1 + host2NICID = 4 + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectedWriteErr *tcpip.Error + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + remoteAddr: ipv4Addr2.AddressWithPrefix.Address, + expectedWriteErr: nil, + }, + { + name: "IPv6", + netProto: ipv6.ProtocolNumber, + remoteAddr: ipv6Addr2.AddressWithPrefix.Address, + expectedWriteErr: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + } + + host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID) + + var serverWQ waiter.Queue + serverWE, serverCH := waiter.NewChannelEntry(nil) + serverWQ.EventRegister(&serverWE, waiter.EventIn) + serverEP, err := host2Stack.NewEndpoint(udp.ProtocolNumber, test.netProto, &serverWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.netProto, err) + } + defer serverEP.Close() + + serverAddr := tcpip.FullAddress{Port: 1234} + if err := serverEP.Bind(serverAddr); err != nil { + t.Fatalf("serverEP.Bind(%#v): %s", serverAddr, err) + } + + r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */) + if err != nil { + t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err) + } + defer r.Release() + + data := []byte{1, 2} + var pkts stack.PacketBufferList + for _, d := range data { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), + Data: buffer.View([]byte{d}).ToVectorisedView(), + }) + pkt.TransportProtocolNumber = udp.ProtocolNumber + length := uint16(pkt.Size()) + udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + udpHdr.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: serverAddr.Port, + Length: length, + }) + xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length) + for _, v := range pkt.Data.Views() { + xsum = header.Checksum(v, xsum) + } + udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) + + pkts.PushBack(pkt) + } + + params := stack.NetworkHeaderParams{ + Protocol: udp.ProtocolNumber, + TTL: 64, + TOS: stack.DefaultTOS, + } + + if n, err := r.WritePackets(nil /* gso */, pkts, params); err != nil { + t.Fatalf("r.WritePackets(nil, %#v, _): %s", params, err) + } else if want := pkts.Len(); want != n { + t.Fatalf("got r.WritePackets(nil, %#v, _) = %d, want = %d", n, params, want) + } + + var writer bytes.Buffer + count := 0 + for { + var rOpts tcpip.ReadOptions + res, err := serverEP.Read(&writer, rOpts) + if err != nil { + if err == tcpip.ErrWouldBlock { + // Should not have anymore bytes to read after we read the sent + // number of bytes. + if count == len(data) { + break + } + + <-serverCH + continue + } + + t.Fatalf("serverEP.Read(_, %#v): %s", rOpts, err) + } + count += res.Count + } + + if got, want := host2Stack.Stats().UDP.PacketsReceived.Value(), uint64(len(data)); got != want { + t.Errorf("got host2Stack.Stats().UDP.PacketsReceived.Value() = %d, want = %d", got, want) + } + if diff := cmp.Diff(data, writer.Bytes()); diff != "" { + t.Errorf("read bytes mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 3b13ba04d..761283b66 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -232,7 +232,9 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { Port: localPort, }, } - n, err := sep.Write(tcpip.SlicePayload(data), wopts) + var r bytes.Reader + r.Reset(data) + n, err := sep.Write(&r, wopts) if err != nil { t.Fatalf("sep.Write(_, _): %s", err) } diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index ce7c16bd1..9cc12fa58 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -586,8 +586,10 @@ func TestReuseAddrAndBroadcast(t *testing.T) { Port: localPort, }, } - data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4}) - if n, err := wep.ep.Write(data, writeOpts); err != nil { + data := []byte{byte(i), 2, 3, 4} + var r bytes.Reader + r.Reset(data) + if n, err := wep.ep.Write(&r, writeOpts); err != nil { t.Fatalf("eps[%d].Write(_, _): %s", i, err) } else if want := int64(len(data)); n != want { t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want) diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index b222d2b05..35ee7437a 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -194,9 +194,11 @@ func TestLocalPing(t *testing.T) { return } - payload := tcpip.SlicePayload(test.icmpBuf(t)) + payload := test.icmpBuf(t) + var r bytes.Reader + r.Reset(payload) var wOpts tcpip.WriteOptions - if n, err := ep.Write(payload, wOpts); err != nil { + if n, err := ep.Write(&r, wOpts); err != nil { t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) } else if n != int64(len(payload)) { t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload)) @@ -329,12 +331,14 @@ func TestLocalUDP(t *testing.T) { Port: 80, } - clientPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + clientPayload := []byte{1, 2, 3, 4} { + var r bytes.Reader + r.Reset(clientPayload) wOpts := tcpip.WriteOptions{ To: &serverAddr, } - if n, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr { + if n, err := client.Write(&r, wOpts); err != subTest.expectedWriteErr { t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr) } else if subTest.expectedWriteErr != nil { // Nothing else to test if we expected not to be able to send the @@ -376,12 +380,14 @@ func TestLocalUDP(t *testing.T) { } } - serverPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + serverPayload := []byte{1, 2, 3, 4} { + var r bytes.Reader + r.Reset(serverPayload) wOpts := tcpip.WriteOptions{ To: &clientAddr, } - if n, err := server.Write(serverPayload, wOpts); err != nil { + if n, err := server.Write(&r, wOpts); err != nil { t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err) } else if n != int64(len(serverPayload)) { t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload)) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 256e19296..e4bcd3120 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -69,8 +69,7 @@ type endpoint struct { rcvClosed bool // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - sndBufSize int + mu sync.RWMutex `state:"nosave"` // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags state endpointState @@ -94,11 +93,17 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, - sndBufSize: 32 * 1024, state: stateInitial, uniqueID: s.UniqueID(), } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.SetSendBufferSize(32*1024, false /* notify */) + + // Override with stack defaults. + var ss tcpip.SendBufferSizeOption + if err := s.Option(&ss); err == nil { + ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) + } return ep, nil } @@ -119,7 +124,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */) + e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */) } // Close the receive list and drain it. @@ -313,11 +318,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc route = r } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, tcpip.ErrBadBuffer } + var err *tcpip.Error switch e.NetProto { case header.IPv4ProtocolNumber: err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) @@ -362,11 +368,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - e.mu.Lock() - v := e.sndBufSize - e.mu.Unlock() - return v, nil case tcpip.ReceiveBufferSizeOption: e.rcvMu.Lock() @@ -578,14 +579,14 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */) + err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) + err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) switch err { case nil: return true, nil diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index 9d263c0ec..afe96998a 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -69,6 +69,7 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) if e.state != stateBound && e.state != stateConnected { return diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index c0d6fb442..d48877677 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -79,13 +79,11 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - sndBufSizeMax int - closed bool - stats tcpip.TransportEndpointStats `state:"nosave"` - bound bool - boundNIC tcpip.NICID + mu sync.RWMutex `state:"nosave"` + closed bool + stats tcpip.TransportEndpointStats `state:"nosave"` + bound bool + boundNIC tcpip.NICID // lastErrorMu protects lastError. lastErrorMu sync.Mutex `state:"nosave"` @@ -106,14 +104,13 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb netProto: netProto, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, - sndBufSize: 32 * 1024, } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) // Override with stack defaults. - var ss stack.SendBufferSizeOption + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { - ep.sndBufSizeMax = ss.Default + ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } var rs stack.ReceiveBufferSizeOption @@ -207,7 +204,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul return res, nil } -func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, *tcpip.Error) { // TODO(gvisor.dev/issue/173): Implement. return 0, tcpip.ErrInvalidOptionValue } @@ -320,24 +317,6 @@ func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss stack.SendBufferSizeOption - if err := ep.stack.Option(&ss); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) - } - if v > ss.Max { - v = ss.Max - } - if v < ss.Min { - v = ss.Min - } - ep.mu.Lock() - ep.sndBufSizeMax = v - ep.mu.Unlock() - return nil - case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -395,12 +374,6 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { ep.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - ep.mu.Lock() - v := ep.sndBufSizeMax - ep.mu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: ep.rcvMu.Lock() v := ep.rcvBufSizeMax diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index e2fa96d17..4d98fb051 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -63,8 +63,8 @@ func (ep *endpoint) loadRcvBufSizeMax(max int) { // afterLoad is invoked by stateify. func (ep *endpoint) afterLoad() { - // StackFromEnv is a stack used specifically for save/restore. ep.stack = stack.StackFromEnv + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index ae743f75e..6c6d45188 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -76,12 +76,10 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - sndBufSizeMax int - closed bool - connected bool - bound bool + mu sync.RWMutex `state:"nosave"` + closed bool + connected bool + bound bool // route is the route to a remote network endpoint. It is set via // Connect(), and is valid only when conneted is true. route *stack.Route `state:"manual"` @@ -112,16 +110,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, - sndBufSizeMax: 32 * 1024, associated: associated, } - e.ops.InitHandler(e) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) e.ops.SetHeaderIncluded(!associated) + e.ops.SetSendBufferSize(32*1024, false /* notify */) // Override with stack defaults. - var ss stack.SendBufferSizeOption + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { - e.sndBufSizeMax = ss.Default + e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } var rs stack.ReceiveBufferSizeOption @@ -138,7 +136,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt return e, nil } - if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { return nil, err } @@ -159,7 +157,7 @@ func (e *endpoint) Close() { return } - e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) e.rcvMu.Lock() defer e.rcvMu.Unlock() @@ -280,9 +278,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return 0, tcpip.ErrInvalidEndpointState } - payloadBytes, err := p.FullPayload() - if err != nil { - return 0, err + payloadBytes := make([]byte, p.Len()) + if _, err := io.ReadFull(p, payloadBytes); err != nil { + return 0, tcpip.ErrBadBuffer } // If this is an unassociated socket and callee provided a nonzero @@ -407,15 +405,18 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { route.Release() return err } - e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) e.RegisterNICID = nic } - // Save the route we've connected via. + if e.route != nil { + // If the endpoint was previously connected then release any previous route. + e.route.Release() + } e.route = route e.connected = true @@ -449,16 +450,16 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { defer e.mu.Unlock() // If a local address was specified, verify that it's valid. - if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { + if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, addr.Addr) == 0 { return tcpip.ErrBadLocalAddress } if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { return err } - e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) e.RegisterNICID = addr.NIC e.BindNICID = addr.NIC } @@ -511,24 +512,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss stack.SendBufferSizeOption - if err := e.stack.Option(&ss); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) - } - if v > ss.Max { - v = ss.Max - } - if v < ss.Min { - v = ss.Min - } - e.mu.Lock() - e.sndBufSizeMax = v - e.mu.Unlock() - return nil - case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -570,12 +553,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - e.mu.Lock() - v := e.sndBufSizeMax - e.mu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: e.rcvMu.Lock() v := e.rcvBufSizeMax diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 4a7e1c039..65c64d99f 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -69,6 +69,7 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) // If the endpoint is connected, re-connect. if e.connected { @@ -93,7 +94,7 @@ func (e *endpoint) Resume(s *stack.Stack) { } if e.associated { - if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { panic(err) } } diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 7e81203ba..fcdd032c5 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -99,7 +99,6 @@ go_test( "//pkg/rand", "//pkg/sync", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/loopback", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 6921de0f1..e475c36f3 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -288,7 +288,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q } // Register new endpoint so that packets are routed to it. - if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil { + if err := ep.stack.RegisterTransportEndpoint(ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil { ep.mu.Unlock() ep.Close() @@ -692,7 +692,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er } // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil { n.mu.Unlock() n.Close() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index f711cd4df..62954d7e4 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -895,43 +895,46 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn return err } -func (e *endpoint) handleWrite() *tcpip.Error { - // Move packets from send queue to send list. The queue is accessible - // from other goroutines and protected by the send mutex, while the send - // list is only accessible from the handler goroutine, so it needs no - // mutexes. +func (e *endpoint) handleWrite() { e.sndBufMu.Lock() + next := e.drainSendQueueLocked() + e.sndBufMu.Unlock() + + e.sendData(next) +} +// Move packets from send queue to send list. +// +// Precondition: e.sndBufMu must be locked. +func (e *endpoint) drainSendQueueLocked() *segment { first := e.sndQueue.Front() if first != nil { e.snd.writeList.PushBackList(&e.sndQueue) e.sndBufInQueue = 0 } + return first +} - e.sndBufMu.Unlock() - +// Precondition: e.mu must be locked. +func (e *endpoint) sendData(next *segment) { // Initialize the next segment to write if it's currently nil. if e.snd.writeNext == nil { - e.snd.writeNext = first + e.snd.writeNext = next } // Push out any new packets. e.snd.sendData() - - return nil } -func (e *endpoint) handleClose() *tcpip.Error { +func (e *endpoint) handleClose() { if !e.EndpointState().connected() { - return nil + return } // Drain the send queue. e.handleWrite() // Mark send side as closed. e.snd.closed = true - - return nil } // resetConnectionLocked puts the endpoint in an error state with the given @@ -1348,11 +1351,17 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }{ { w: &e.sndWaker, - f: e.handleWrite, + f: func() *tcpip.Error { + e.handleWrite() + return nil + }, }, { w: &e.sndCloseWaker, - f: e.handleClose, + f: func() *tcpip.Error { + e.handleClose() + return nil + }, }, { w: &closeWaker, diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 1d1b01a6c..809c88732 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -15,11 +15,11 @@ package tcp_test import ( + "strings" "testing" "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -415,8 +415,10 @@ func testV4Accept(t *testing.T, c *context.Context) { t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr) } + var r strings.Reader data := "Don't panic" - nep.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + r.Reset(data) + nep.Write(&r, tcpip.WriteOptions{}) b = c.GetPacket() tcp = header.TCP(header.IPv4(b).Payload()) if string(tcp.Payload()) != data { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ea509ac73..b6bd6d455 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -557,7 +557,6 @@ type endpoint struct { // When the send side is closed, the protocol goroutine is notified via // sndCloseWaker, and sndClosed is set to true. sndBufMu sync.Mutex `state:"nosave"` - sndBufSize int sndBufUsed int sndClosed bool sndBufInQueue seqnum.Size @@ -869,7 +868,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue waiterQueue: waiterQueue, state: StateInitial, rcvBufSize: DefaultReceiveBufferSize, - sndBufSize: DefaultSendBufferSize, sndMTU: int(math.MaxInt32), keepalive: keepalive{ // Linux defaults. @@ -882,13 +880,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue windowClamp: DefaultReceiveBufferSize, maxSynRetries: DefaultSynRetries, } - e.ops.InitHandler(e) + e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) e.ops.SetMulticastLoop(true) e.ops.SetQuickAck(true) + e.ops.SetSendBufferSize(DefaultSendBufferSize, false /* notify */) var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { - e.sndBufSize = ss.Default + e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } var rs tcpip.TCPReceiveBufferSizeRangeOption @@ -967,7 +966,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // Determine if the endpoint is writable if requested. if (mask & waiter.EventOut) != 0 { e.sndBufMu.Lock() - if e.sndClosed || e.sndBufUsed < e.sndBufSize { + sndBufSize := e.getSendBufferSize() + if e.sndClosed || e.sndBufUsed < sndBufSize { result |= waiter.EventOut } e.sndBufMu.Unlock() @@ -1087,7 +1087,7 @@ func (e *endpoint) closeNoShutdownLocked() { // in Listen() when trying to register. if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } @@ -1161,7 +1161,7 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } @@ -1499,7 +1499,8 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { return 0, tcpip.ErrClosedForSend } - avail := e.sndBufSize - e.sndBufUsed + sndBufSize := e.getSendBufferSize() + avail := sndBufSize - e.sndBufUsed if avail <= 0 { return 0, tcpip.ErrWouldBlock } @@ -1513,69 +1514,79 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // and opts.EndOfRecord are also ignored. e.LockUser() - e.sndBufMu.Lock() - - avail, err := e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.UnlockUser() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, err - } - - // We can release locks while copying data. - // - // This is not possible if atomic is set, because we can't allow the - // available buffer space to be consumed by some other caller while we - // are copying data in. - if !opts.Atomic { - e.sndBufMu.Unlock() - e.UnlockUser() - } - - // Fetch data. - v, perr := p.Payload(avail) - if perr != nil || len(v) == 0 { - // Note that perr may be nil if len(v) == 0. - if opts.Atomic { - e.sndBufMu.Unlock() - e.UnlockUser() - } - return 0, perr - } + defer e.UnlockUser() - if !opts.Atomic { - // Since we released locks in between it's possible that the - // endpoint transitioned to a CLOSED/ERROR states so make - // sure endpoint is still writable before trying to write. - e.LockUser() + nextSeg, n, err := func() (*segment, int, *tcpip.Error) { e.sndBufMu.Lock() + defer e.sndBufMu.Unlock() + avail, err := e.isEndpointWritableLocked() if err != nil { - e.sndBufMu.Unlock() - e.UnlockUser() e.stats.WriteErrors.WriteClosed.Increment() - return 0, err + return nil, 0, err } - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] + v, err := func() ([]byte, *tcpip.Error) { + // We can release locks while copying data. + // + // This is not possible if atomic is set, because we can't allow the + // available buffer space to be consumed by some other caller while we + // are copying data in. + if !opts.Atomic { + e.sndBufMu.Unlock() + defer e.sndBufMu.Lock() + + e.UnlockUser() + defer e.LockUser() + } + + // Fetch data. + if l := p.Len(); l < avail { + avail = l + } + if avail == 0 { + return nil, nil + } + v := make([]byte, avail) + if _, err := io.ReadFull(p, v); err != nil { + return nil, tcpip.ErrBadBuffer + } + return v, nil + }() + if len(v) == 0 || err != nil { + return nil, 0, err + } + + if !opts.Atomic { + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. + avail, err := e.isEndpointWritableLocked() + if err != nil { + e.stats.WriteErrors.WriteClosed.Increment() + return nil, 0, err + } + + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } } - } - // Add data to the send queue. - s := newOutgoingSegment(e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() + // Add data to the send queue. + s := newOutgoingSegment(e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) - // Do the work inline. - e.handleWrite() - e.UnlockUser() - return int64(len(v)), nil + return e.drainSendQueueLocked(), len(v), nil + }() + if err != nil { + return 0, err + } + e.sendData(nextSeg) + return int64(n), nil } // selectWindowLocked returns the new window without checking for shrinking or scaling @@ -1682,6 +1693,14 @@ func (e *endpoint) OnCorkOptionSet(v bool) { } } +func (e *endpoint) getSendBufferSize() int { + sndBufSize, err := e.ops.GetSendBufferSize() + if err != nil { + panic(fmt.Sprintf("e.ops.GetSendBufferSize() = %s", err)) + } + return int(sndBufSize) +} + // SetSockOptInt sets a socket option. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 @@ -1775,31 +1794,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.rcvListMu.Unlock() e.UnlockUser() - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss tcpip.TCPSendBufferSizeRangeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil { - panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &ss, err)) - } - - if v > ss.Max { - v = ss.Max - } - - if v < math.MaxInt32/SegOverheadFactor { - v *= SegOverheadFactor - if v < ss.Min { - v = ss.Min - } - } else { - v = math.MaxInt32 - } - - e.sndBufMu.Lock() - e.sndBufSize = v - e.sndBufMu.Unlock() - case tcpip.TTLOption: e.LockUser() e.ttl = uint8(v) @@ -1985,12 +1979,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() - case tcpip.SendBufferSizeOption: - e.sndBufMu.Lock() - v := e.sndBufSize - e.sndBufMu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: e.rcvListMu.Lock() v := e.rcvBufSize @@ -2190,7 +2178,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) if err != nil { return err } @@ -2278,7 +2266,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc id := e.ID id.LocalPort = p - if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr) if err == tcpip.ErrPortInUse { return false, nil @@ -2482,7 +2470,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { return err } @@ -2604,7 +2592,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // demuxer. Further connected endpoints always have a remote // address/port. Hence this will only return an error if there is a matching // listening endpoint. - if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil { + if err := e.stack.CheckRegisterTransportEndpoint(netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil { return false } return true @@ -2739,13 +2727,14 @@ func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt // updateSndBufferUsage is called by the protocol goroutine when room opens up // in the send buffer. The number of newly available bytes is v. func (e *endpoint) updateSndBufferUsage(v int) { + sendBufferSize := e.getSendBufferSize() e.sndBufMu.Lock() - notify := e.sndBufUsed >= e.sndBufSize>>1 + notify := e.sndBufUsed >= sendBufferSize>>1 e.sndBufUsed -= v - // We only notify when there is half the sndBufSize available after + // We only notify when there is half the sendBufferSize available after // a full buffer event occurs. This ensures that we don't wake up // writers to queue just 1-2 segments and go back to sleep. - notify = notify && e.sndBufUsed < e.sndBufSize>>1 + notify = notify && e.sndBufUsed < int(sendBufferSize)>>1 e.sndBufMu.Unlock() if notify { @@ -2957,8 +2946,9 @@ func (e *endpoint) completeState() stack.TCPEndpointState { s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy() // Copy endpoint send state. + sndBufSize := e.getSendBufferSize() e.sndBufMu.Lock() - s.SndBufSize = e.sndBufSize + s.SndBufSize = sndBufSize s.SndBufUsed = e.sndBufUsed s.SndClosed = e.sndClosed s.SndBufInQueue = e.sndBufInQueue @@ -3103,3 +3093,17 @@ func (e *endpoint) Wait() { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// GetTCPSendBufferLimits is used to get send buffer size limits for TCP. +func GetTCPSendBufferLimits(s tcpip.StackHandler) tcpip.SendBufferSizeOption { + var ss tcpip.TCPSendBufferSizeRangeOption + if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err != nil { + panic(fmt.Sprintf("s.TransportProtocolOption(%d, %#v) = %s", header.TCPProtocolNumber, ss, err)) + } + + return tcpip.SendBufferSizeOption{ + Min: ss.Min, + Default: ss.Default, + Max: ss.Max, + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index ba67176b5..4a01c81b4 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -55,7 +55,9 @@ func (e *endpoint) beforeSave() { case epState.connected() || epState.handshake(): if !e.route.HasSaveRestoreCapability() { if !e.route.HasDisconncetOkCapability() { - panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) + panic(&tcpip.ErrSaveRejection{ + Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort), + }) } e.resetConnectionLocked(tcpip.ErrConnectionAborted) e.mu.Unlock() @@ -179,14 +181,16 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s + e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) e.segmentQueue.thaw() epState := e.origEndpointState switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss tcpip.TCPSendBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { - if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) + sendBufferSize := e.getSendBufferSize() + if sendBufferSize < ss.Min || sendBufferSize > ss.Max { + panic(fmt.Sprintf("endpoint sendBufferSize %d is outside the min and max allowed [%d, %d]", sendBufferSize, ss.Min, ss.Max)) } } diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index f7aaee23f..ced3a9c58 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -21,13 +21,13 @@ package tcp_test import ( + "bytes" "fmt" "math" "testing" "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" @@ -42,14 +42,16 @@ func TestFastRecovery(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -207,14 +209,16 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -249,14 +253,16 @@ func TestCongestionAvoidance(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -353,15 +359,16 @@ func TestCubicCongestionAvoidance(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -462,19 +469,20 @@ func TestRetransmit(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in two shots. Packets will only be written at the // MTU size though. - half := data[:len(data)/2] - if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data[:len(data)/2]) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } - half = data[len(data)/2:] - if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + r.Reset(data[len(data)/2:]) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index 342eb5eb8..af915203b 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -15,11 +15,11 @@ package tcp_test import ( + "bytes" "testing" "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -61,14 +61,16 @@ func TestRACKUpdate(t *testing.T) { setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - data := buffer.NewView(maxPayload) + data := make([]byte, maxPayload) for i := range data { data[i] = byte(i) } // Write the data. xmitTime = time.Now() - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -114,13 +116,15 @@ func TestRACKDetectReorder(t *testing.T) { }) setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - data := buffer.NewView(ackNumToVerify * maxPayload) + data := make([]byte, ackNumToVerify*maxPayload) for i := range data { data[i] = byte(i) } // Write the data. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -141,17 +145,19 @@ func TestRACKDetectReorder(t *testing.T) { <-probeDone } -func sendAndReceive(t *testing.T, c *context.Context, numPackets int) buffer.View { +func sendAndReceive(t *testing.T, c *context.Context, numPackets int) []byte { setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - data := buffer.NewView(numPackets * maxPayload) + data := make([]byte, numPackets*maxPayload) for i := range data { data[i] = byte(i) } // Write the data. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 6635bb815..5024bc925 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -15,6 +15,7 @@ package tcp_test import ( + "bytes" "fmt" "log" "reflect" @@ -22,7 +23,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -395,14 +395,16 @@ func TestSACKRecovery(t *testing.T) { createConnectedWithSACKAndTS(c) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 93683b921..87ff2b909 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -19,6 +19,7 @@ import ( "fmt" "io/ioutil" "math" + "strings" "testing" "time" @@ -26,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -1347,10 +1347,9 @@ func TestTOSV4(t *testing.T) { testV4Connect(t, c, checker.TOS(tos, 0)) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -1396,10 +1395,9 @@ func TestTrafficClassV6(t *testing.T) { testV6Connect(t, c, checker.TOS(tos, 0)) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2176,10 +2174,9 @@ func TestSimpleSend(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2217,10 +2214,9 @@ func TestZeroWindowSend(t *testing.T) { c.CreateConnected(789 /* iss */, 0 /* rcvWnd */, -1 /* epRcvBuf */) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2285,10 +2281,9 @@ func TestScaledWindowConnect(t *testing.T) { }) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2317,10 +2312,9 @@ func TestNonScaledWindowConnect(t *testing.T) { c.CreateConnected(789, 30000, 65535*3) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2391,10 +2385,9 @@ func TestScaledWindowAccept(t *testing.T) { } data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2465,10 +2458,9 @@ func TestNonScaledWindowAccept(t *testing.T) { } data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2632,9 +2624,10 @@ func TestSegmentMerging(t *testing.T) { // Send tcp.InitialCwnd number of segments to fill up // InitialWindow but don't ACK. That should prevent // anymore packets from going out. + var r bytes.Reader for i := 0; i < tcp.InitialCwnd; i++ { - view := buffer.NewViewFromBytes([]byte{0}) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset([]byte{0}) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2644,8 +2637,8 @@ func TestSegmentMerging(t *testing.T) { var allData []byte for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2714,8 +2707,9 @@ func TestDelay(t *testing.T) { var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2761,8 +2755,9 @@ func TestUndelay(t *testing.T) { allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2845,8 +2840,9 @@ func TestMSSNotDelayed(t *testing.T) { allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)} for i, data := range allData { - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2894,10 +2890,9 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { data[i] = byte(i) } - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3328,8 +3323,9 @@ func TestSendOnResetConnection(t *testing.T) { time.Sleep(1 * time.Second) // Try to write. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { + var r bytes.Reader + r.Reset(make([]byte, 10)) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset) } } @@ -3352,7 +3348,9 @@ func TestMaxRetransmitsTimeout(t *testing.T) { c.WQ.EventRegister(&waitEntry, waiter.EventHUp) defer c.WQ.EventUnregister(&waitEntry) - _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + var r bytes.Reader + r.Reset(make([]byte, 1)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) if err != nil { t.Fatalf("Write failed: %s", err) } @@ -3409,7 +3407,9 @@ func TestMaxRTO(t *testing.T) { c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + var r bytes.Reader + r.Reset(make([]byte, 1)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) if err != nil { t.Fatalf("Write failed: %s", err) } @@ -3458,7 +3458,9 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) } - if _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(make([]byte, tc.size)) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } pkt := c.GetPacket() @@ -3595,8 +3597,10 @@ func TestFinWithNoPendingData(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and have it acknowledged. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 10) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3667,9 +3671,11 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { // Write enough segments to fill the congestion window before ACK'ing // any of them. - view := buffer.NewView(10) + view := make([]byte, 10) + var r bytes.Reader for i := tcp.InitialCwnd; i > 0; i-- { - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } } @@ -3754,8 +3760,10 @@ func TestFinWithPendingData(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 10) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3781,7 +3789,8 @@ func TestFinWithPendingData(t *testing.T) { }) // Write new data, but don't acknowledge it. - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3841,8 +3850,10 @@ func TestFinWithPartialAck(t *testing.T) { // Write something out, and acknowledge it to get cwnd to 2. Also send // FIN from the test side. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 10) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3879,7 +3890,8 @@ func TestFinWithPartialAck(t *testing.T) { ) // Write new data, but don't acknowledge it. - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3985,8 +3997,10 @@ func scaledSendWindow(t *testing.T, scale uint8) { }) // Send some data. Check that it's capped by the window size. - view := buffer.NewView(65535) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 65535) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -4351,9 +4365,9 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + s, err := ep.SocketOptions().GetSendBufferSize() if err != nil { - t.Fatalf("GetSockOpt failed: %s", err) + t.Fatalf("GetSendBufferSize failed: %s", err) } if int(s) != v { @@ -4459,9 +4473,7 @@ func TestMinMaxBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, 200) - if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 149); err != nil { - t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err) - } + ep.SocketOptions().SetSendBufferSize(149, true) checkSendBufferSize(t, ep, 300) @@ -4473,9 +4485,7 @@ func TestMinMaxBufferSizes(t *testing.T) { // Values above max are capped at max and then doubled. checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2) - if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { - t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err) - } + ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true) // Values above max are capped at max and then doubled. checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2) @@ -4610,9 +4620,9 @@ func TestSelfConnect(t *testing.T) { // Write something. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -4785,12 +4795,13 @@ func TestPathMTUDiscovery(t *testing.T) { // Send 3200 bytes of data. const writeSize = 3200 - data := buffer.NewView(writeSize) + data := make([]byte, writeSize) for i := range data { data[i] = byte(i) } - - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -5078,8 +5089,10 @@ func TestKeepalive(t *testing.T) { // Send some data and wait before ACKing it. Keepalives should be disabled // during this period. - view := buffer.NewView(3) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 3) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -5358,7 +5371,9 @@ func TestListenBacklogFull(t *testing.T) { // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" - newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + var r strings.Reader + r.Reset(data) + newEP.Write(&r, tcpip.WriteOptions{}) b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) if string(tcp.Payload()) != data { @@ -5674,7 +5689,9 @@ func TestListenSynRcvdQueueFull(t *testing.T) { // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" - newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + var r strings.Reader + r.Reset(data) + newEP.Write(&r, tcpip.WriteOptions{}) pkt := c.GetPacket() tcp = header.TCP(header.IPv4(pkt).Payload()) if string(tcp.Payload()) != data { @@ -5908,7 +5925,9 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" - if _, err := newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}); err != nil { + var r strings.Reader + r.Reset(data) + if _, err := newEP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -7103,10 +7122,10 @@ func TestTCPCloseWithData(t *testing.T) { // Now write a few bytes and then close the endpoint. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -7204,8 +7223,10 @@ func TestTCPUserTimeout(t *testing.T) { } // Send some data and wait before ACKing it. - view := buffer.NewView(3) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 3) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index b65091c3c..5a9745ad7 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -22,7 +22,6 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -152,10 +151,10 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS // Now send some data and validate that timestamp is echoed correctly in the ACK. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Unexpected error from Write: %s", err) } @@ -215,10 +214,10 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd // Now send some data with the accepted connection endpoint and validate // that no timestamp option is sent in the TCP segment. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Unexpected error from Write: %s", err) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 9f9b3d510..4988ba29b 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -97,9 +97,7 @@ type endpoint struct { rcvClosed bool // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - sndBufSizeMax int + mu sync.RWMutex `state:"nosave"` // state must be read/set using the EndpointState()/setEndpointState() // methods. state EndpointState @@ -176,18 +174,18 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // Linux defaults to TTL=1. multicastTTL: 1, rcvBufSizeMax: 32 * 1024, - sndBufSizeMax: 32 * 1024, multicastMemberships: make(map[multicastMembership]struct{}), state: StateInitial, uniqueID: s.UniqueID(), } - e.ops.InitHandler(e) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) e.ops.SetMulticastLoop(true) + e.ops.SetSendBufferSize(32*1024, false /* notify */) // Override with stack defaults. - var ss stack.SendBufferSizeOption + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { - e.sndBufSizeMax = ss.Default + e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } var rs stack.ReceiveBufferSizeOption @@ -246,7 +244,7 @@ func (e *endpoint) Close() { switch e.EndpointState() { case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} @@ -514,9 +512,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return 0, tcpip.ErrBroadcastDisabled } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, tcpip.ErrBadBuffer } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. @@ -632,25 +630,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.rcvBufSizeMax = v e.mu.Unlock() return nil - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss stack.SendBufferSizeOption - if err := e.stack.Option(&ss); err != nil { - panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err)) - } - - if v < ss.Min { - v = ss.Min - } - if v > ss.Max { - v = ss.Max - } - - e.mu.Lock() - e.sndBufSizeMax = v - e.mu.Unlock() - return nil } return nil @@ -811,12 +790,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - e.mu.Lock() - v := e.sndBufSizeMax - e.mu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: e.rcvMu.Lock() v := e.rcvBufSizeMax @@ -935,7 +908,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { LocalPort: e.ID.LocalPort, LocalAddress: e.ID.LocalAddress, } - id, btd, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) + id, btd, err = e.registerWithStack(e.effectiveNetProtos, id) if err != nil { return err } @@ -950,7 +923,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { e.setEndpointState(StateInitial) } - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) e.ID = id e.boundBindToDevice = btd e.route.Release() @@ -1023,7 +996,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { oldPortFlags := e.boundPortFlags - id, btd, err := e.registerWithStack(nicID, netProtos, id) + id, btd, err := e.registerWithStack(netProtos, id) if err != nil { r.Release() return err @@ -1031,11 +1004,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Remove the old registration. if e.ID.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice) } e.ID = id e.boundBindToDevice = btd + if e.route != nil { + // If the endpoint was already connected then make sure we release the + // previous route. + e.route.Release() + } e.route = r e.dstPort = addr.Port e.RegisterNICID = nicID @@ -1093,7 +1071,7 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcp return nil, nil, tcpip.ErrNotSupported } -func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { +func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if e.ID.LocalPort == 0 { port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */) @@ -1104,7 +1082,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ } e.boundPortFlags = e.portFlags - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice) + err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice) if err != nil { e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} @@ -1148,7 +1126,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { LocalPort: addr.Port, LocalAddress: addr.Addr, } - id, btd, err := e.registerWithStack(nicID, netProtos, id) + id, btd, err := e.registerWithStack(netProtos, id) if err != nil { return err } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 13b72dc88..feb53b553 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -91,6 +91,7 @@ func (e *endpoint) Resume(s *stack.Stack) { defer e.mu.Unlock() e.stack = s + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) for m := range e.multicastMemberships { if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { @@ -131,7 +132,7 @@ func (e *endpoint) Resume(s *stack.Stack) { // pass it to the reservation machinery. id := e.ID e.ID.LocalPort = 0 - e.ID, e.boundBindToDevice, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) + e.ID, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 49e673d58..aae794506 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -77,7 +77,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, } ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) - if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { + if err := r.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { ep.Close() route.Release() return nil, err diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 4e2123fe9..c4794e876 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -966,8 +966,9 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { h := flow.header4Tuple(outgoing) writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) - payload := buffer.View(newPayload()) - _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + var r bytes.Reader + r.Reset(newPayload()) + _, gotErr := c.ep.Write(&r, tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, }) c.checkEndpointWriteStats(1, epstats, gotErr) @@ -1007,8 +1008,10 @@ func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, } } - payload := buffer.View(newPayload()) - n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) + var r bytes.Reader + payload := newPayload() + r.Reset(payload) + n, err := c.ep.Write(&r, writeOpts) if err != nil { c.t.Fatalf("Write failed: %s", err) } @@ -1183,8 +1186,10 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) { writeOpts := tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}, } - payload := buffer.View(newPayload()) - n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) + var r bytes.Reader + payload := newPayload() + r.Reset(payload) + n, err := c.ep.Write(&r, writeOpts) if err != nil { c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err) } @@ -2497,7 +2502,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { } defer ep.Close() - data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + var r bytes.Reader + data := []byte{1, 2, 3, 4} to := tcpip.FullAddress{ Addr: test.remoteAddr, Port: 80, @@ -2508,19 +2514,22 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { expectedErrWithoutBcastOpt = nil } - if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { + r.Reset(data) + if n, err := ep.Write(&r, opts); err != expectedErrWithoutBcastOpt { t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt) } ep.SocketOptions().SetBroadcast(true) - if n, err := ep.Write(data, opts); err != nil { + r.Reset(data) + if n, err := ep.Write(&r, opts); err != nil { t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) } ep.SocketOptions().SetBroadcast(false) - if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { + r.Reset(data) + if n, err := ep.Write(&r, opts); err != expectedErrWithoutBcastOpt { t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt) } }) diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index 79db8895b..dc2571154 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -517,28 +517,29 @@ func (s IOSequence) CopyInTo(ctx context.Context, dst safemem.Writer) (int64, er // Reader returns an io.Reader that reads from s. Reads beyond the end of s // return io.EOF. The preconditions that apply to s.CopyIn also apply to the // returned io.Reader.Read. -func (s IOSequence) Reader(ctx context.Context) io.Reader { - return &ioSequenceReadWriter{ctx, s} +func (s IOSequence) Reader(ctx context.Context) *IOSequenceReadWriter { + return &IOSequenceReadWriter{ctx, s} } // Writer returns an io.Writer that writes to s. Writes beyond the end of s // return ErrEndOfIOSequence. The preconditions that apply to s.CopyOut also // apply to the returned io.Writer.Write. -func (s IOSequence) Writer(ctx context.Context) io.Writer { - return &ioSequenceReadWriter{ctx, s} +func (s IOSequence) Writer(ctx context.Context) *IOSequenceReadWriter { + return &IOSequenceReadWriter{ctx, s} } // ErrEndOfIOSequence is returned by IOSequence.Writer().Write() when // attempting to write beyond the end of the IOSequence. var ErrEndOfIOSequence = errors.New("write beyond end of IOSequence") -type ioSequenceReadWriter struct { +// IOSequenceReadWriter implements io.Reader and io.Writer for an IOSequence. +type IOSequenceReadWriter struct { ctx context.Context s IOSequence } // Read implements io.Reader.Read. -func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) { +func (rw *IOSequenceReadWriter) Read(dst []byte) (int, error) { n, err := rw.s.CopyIn(rw.ctx, dst) rw.s = rw.s.DropFirst(n) if err == nil && rw.s.NumBytes() == 0 { @@ -547,8 +548,13 @@ func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) { return n, err } +// Len implements tcpip.Payloader. +func (rw *IOSequenceReadWriter) Len() int { + return int(rw.s.NumBytes()) +} + // Write implements io.Writer.Write. -func (rw *ioSequenceReadWriter) Write(src []byte) (int, error) { +func (rw *IOSequenceReadWriter) Write(src []byte) (int, error) { n, err := rw.s.CopyOut(rw.ctx, src) rw.s = rw.s.DropFirst(n) if err == nil && n < len(src) { |