diff options
Diffstat (limited to 'pkg')
166 files changed, 5443 insertions, 3176 deletions
diff --git a/pkg/coverage/BUILD b/pkg/coverage/BUILD index 5df7ca831..a198e8028 100644 --- a/pkg/coverage/BUILD +++ b/pkg/coverage/BUILD @@ -7,7 +7,6 @@ go_library( srcs = ["coverage.go"], visibility = ["//:sandbox"], deps = [ - "//pkg/log", "//pkg/sync", "//pkg/usermem", "@io_bazel_rules_go//go/tools/coverdata", diff --git a/pkg/fdchannel/fdchannel_unsafe.go b/pkg/fdchannel/fdchannel_unsafe.go index 367235be5..b253a8fdd 100644 --- a/pkg/fdchannel/fdchannel_unsafe.go +++ b/pkg/fdchannel/fdchannel_unsafe.go @@ -21,7 +21,6 @@ package fdchannel import ( "fmt" "reflect" - "sync/atomic" "syscall" "unsafe" ) @@ -41,7 +40,7 @@ func NewConnectedSockets() ([2]int, error) { // // Endpoint is not copyable or movable by value. type Endpoint struct { - sockfd int32 // accessed using atomic memory operations + sockfd int32 msghdr syscall.Msghdr cmsg *syscall.Cmsghdr // followed by sizeofInt32 bytes of data } @@ -54,10 +53,10 @@ func (ep *Endpoint) Init(sockfd int) { // sendmsg+recvmsg for a zero-length datagram is slightly faster than // sendmsg+recvmsg for a single byte over a stream socket. cmsgSlice := make([]byte, syscall.CmsgSpace(sizeofInt32)) - cmsgReflect := (*reflect.SliceHeader)((unsafe.Pointer)(&cmsgSlice)) + cmsgReflect := (*reflect.SliceHeader)(unsafe.Pointer(&cmsgSlice)) ep.sockfd = int32(sockfd) - ep.msghdr.Control = (*byte)((unsafe.Pointer)(cmsgReflect.Data)) - ep.cmsg = (*syscall.Cmsghdr)((unsafe.Pointer)(cmsgReflect.Data)) + ep.msghdr.Control = (*byte)(unsafe.Pointer(cmsgReflect.Data)) + ep.cmsg = (*syscall.Cmsghdr)(unsafe.Pointer(cmsgReflect.Data)) // ep.msghdr.Controllen and ep.cmsg.* are mutated by recvmsg(2), so they're // set before calling sendmsg/recvmsg. } @@ -73,12 +72,8 @@ func NewEndpoint(sockfd int) *Endpoint { // Destroy releases resources owned by ep. No other Endpoint methods may be // called after Destroy. func (ep *Endpoint) Destroy() { - // These need not use sync/atomic since there must not be any concurrent - // calls to Endpoint methods. - if ep.sockfd >= 0 { - syscall.Close(int(ep.sockfd)) - ep.sockfd = -1 - } + syscall.Close(int(ep.sockfd)) + ep.sockfd = -1 } // Shutdown causes concurrent and future calls to ep.SendFD(), ep.RecvFD(), and @@ -88,10 +83,7 @@ func (ep *Endpoint) Destroy() { // Shutdown is the only Endpoint method that may be called concurrently with // other methods. func (ep *Endpoint) Shutdown() { - if sockfd := int(atomic.SwapInt32(&ep.sockfd, -1)); sockfd >= 0 { - syscall.Shutdown(sockfd, syscall.SHUT_RDWR) - syscall.Close(sockfd) - } + syscall.Shutdown(int(ep.sockfd), syscall.SHUT_RDWR) } // SendFD sends the open file description represented by the given file @@ -103,7 +95,7 @@ func (ep *Endpoint) SendFD(fd int) error { ep.cmsg.SetLen(cmsgLen) *ep.cmsgData() = int32(fd) ep.msghdr.SetControllen(cmsgLen) - _, _, e := syscall.Syscall(syscall.SYS_SENDMSG, uintptr(atomic.LoadInt32(&ep.sockfd)), uintptr((unsafe.Pointer)(&ep.msghdr)), 0) + _, _, e := syscall.Syscall(syscall.SYS_SENDMSG, uintptr(ep.sockfd), uintptr(unsafe.Pointer(&ep.msghdr)), 0) if e != 0 { return e } @@ -113,7 +105,7 @@ func (ep *Endpoint) SendFD(fd int) error { // RecvFD receives an open file description from the connected Endpoint and // returns a file descriptor representing it, owned by the caller. func (ep *Endpoint) RecvFD() (int, error) { - return ep.recvFD(0) + return ep.recvFD(false) } // RecvFDNonblock receives an open file description from the connected Endpoint @@ -121,13 +113,18 @@ func (ep *Endpoint) RecvFD() (int, error) { // are no pending receivable open file descriptions, RecvFDNonblock returns // (<unspecified>, EAGAIN or EWOULDBLOCK). func (ep *Endpoint) RecvFDNonblock() (int, error) { - return ep.recvFD(syscall.MSG_DONTWAIT) + return ep.recvFD(true) } -func (ep *Endpoint) recvFD(flags uintptr) (int, error) { +func (ep *Endpoint) recvFD(nonblock bool) (int, error) { cmsgLen := syscall.CmsgLen(sizeofInt32) ep.msghdr.SetControllen(cmsgLen) - _, _, e := syscall.Syscall(syscall.SYS_RECVMSG, uintptr(atomic.LoadInt32(&ep.sockfd)), uintptr((unsafe.Pointer)(&ep.msghdr)), flags|syscall.MSG_TRUNC) + var e syscall.Errno + if nonblock { + _, _, e = syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(ep.sockfd), uintptr(unsafe.Pointer(&ep.msghdr)), syscall.MSG_TRUNC|syscall.MSG_DONTWAIT) + } else { + _, _, e = syscall.Syscall(syscall.SYS_RECVMSG, uintptr(ep.sockfd), uintptr(unsafe.Pointer(&ep.msghdr)), syscall.MSG_TRUNC) + } if e != 0 { return -1, e } @@ -142,5 +139,5 @@ func (ep *Endpoint) recvFD(flags uintptr) (int, error) { func (ep *Endpoint) cmsgData() *int32 { // syscall.CmsgLen(0) == syscall.cmsgAlignOf(syscall.SizeofCmsghdr) - return (*int32)((unsafe.Pointer)(uintptr((unsafe.Pointer)(ep.cmsg)) + uintptr(syscall.CmsgLen(0)))) + return (*int32)(unsafe.Pointer(uintptr(unsafe.Pointer(ep.cmsg)) + uintptr(syscall.CmsgLen(0)))) } diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go index e7c3a3a0b..2e8452a02 100644 --- a/pkg/flipcall/ctrl_futex.go +++ b/pkg/flipcall/ctrl_futex.go @@ -40,17 +40,41 @@ func (ep *Endpoint) ctrlInit(opts ...EndpointOption) error { return nil } -type ctrlHandshakeRequest struct{} - -type ctrlHandshakeResponse struct{} - func (ep *Endpoint) ctrlConnect() error { if err := ep.enterFutexWait(); err != nil { return err } - _, err := ep.futexConnect(&ctrlHandshakeRequest{}) - ep.exitFutexWait() - return err + defer ep.exitFutexWait() + + // Write the connection request. + w := ep.NewWriter() + if err := json.NewEncoder(w).Encode(struct{}{}); err != nil { + return fmt.Errorf("error writing connection request: %v", err) + } + *ep.dataLen() = w.Len() + + // Exchange control with the server. + if err := ep.futexSetPeerActive(); err != nil { + return err + } + if err := ep.futexWakePeer(); err != nil { + return err + } + if err := ep.futexWaitUntilActive(); err != nil { + return err + } + + // Read the connection response. + var resp struct{} + respLen := atomic.LoadUint32(ep.dataLen()) + if respLen > ep.dataCap { + return fmt.Errorf("invalid connection response length %d (maximum %d)", respLen, ep.dataCap) + } + if err := json.NewDecoder(ep.NewReader(respLen)).Decode(&resp); err != nil { + return fmt.Errorf("error reading connection response: %v", err) + } + + return nil } func (ep *Endpoint) ctrlWaitFirst() error { @@ -59,52 +83,61 @@ func (ep *Endpoint) ctrlWaitFirst() error { } defer ep.exitFutexWait() - // Wait for the handshake request. - if err := ep.futexSwitchFromPeer(); err != nil { + // Wait for the connection request. + if err := ep.futexWaitUntilActive(); err != nil { return err } - // Read the handshake request. + // Read the connection request. reqLen := atomic.LoadUint32(ep.dataLen()) if reqLen > ep.dataCap { - return fmt.Errorf("invalid handshake request length %d (maximum %d)", reqLen, ep.dataCap) + return fmt.Errorf("invalid connection request length %d (maximum %d)", reqLen, ep.dataCap) } - var req ctrlHandshakeRequest + var req struct{} if err := json.NewDecoder(ep.NewReader(reqLen)).Decode(&req); err != nil { - return fmt.Errorf("error reading handshake request: %v", err) + return fmt.Errorf("error reading connection request: %v", err) } - // Write the handshake response. + // Write the connection response. w := ep.NewWriter() - if err := json.NewEncoder(w).Encode(ctrlHandshakeResponse{}); err != nil { - return fmt.Errorf("error writing handshake response: %v", err) + if err := json.NewEncoder(w).Encode(struct{}{}); err != nil { + return fmt.Errorf("error writing connection response: %v", err) } *ep.dataLen() = w.Len() // Return control to the client. raceBecomeInactive() - if err := ep.futexSwitchToPeer(); err != nil { + if err := ep.futexSetPeerActive(); err != nil { + return err + } + if err := ep.futexWakePeer(); err != nil { return err } - // Wait for the first non-handshake message. - return ep.futexSwitchFromPeer() + // Wait for the first non-connection message. + return ep.futexWaitUntilActive() } func (ep *Endpoint) ctrlRoundTrip() error { - if err := ep.futexSwitchToPeer(); err != nil { + if err := ep.enterFutexWait(); err != nil { return err } - if err := ep.enterFutexWait(); err != nil { + defer ep.exitFutexWait() + + if err := ep.futexSetPeerActive(); err != nil { return err } - err := ep.futexSwitchFromPeer() - ep.exitFutexWait() - return err + if err := ep.futexWakePeer(); err != nil { + return err + } + return ep.futexWaitUntilActive() } func (ep *Endpoint) ctrlWakeLast() error { - return ep.futexSwitchToPeer() + if err := ep.futexSetPeerActive(); err != nil { + return err + } + return ep.futexWakePeer() } func (ep *Endpoint) enterFutexWait() error { diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go index ac974b232..580bf23a4 100644 --- a/pkg/flipcall/flipcall_unsafe.go +++ b/pkg/flipcall/flipcall_unsafe.go @@ -41,11 +41,11 @@ const ( ) func (ep *Endpoint) connState() *uint32 { - return (*uint32)((unsafe.Pointer)(ep.packet)) + return (*uint32)(unsafe.Pointer(ep.packet)) } func (ep *Endpoint) dataLen() *uint32 { - return (*uint32)((unsafe.Pointer)(ep.packet + 4)) + return (*uint32)(unsafe.Pointer(ep.packet + 4)) } // Data returns the datagram part of ep's packet window as a byte slice. @@ -63,7 +63,7 @@ func (ep *Endpoint) dataLen() *uint32 { // all. func (ep *Endpoint) Data() []byte { var bs []byte - bsReflect := (*reflect.SliceHeader)((unsafe.Pointer)(&bs)) + bsReflect := (*reflect.SliceHeader)(unsafe.Pointer(&bs)) bsReflect.Data = ep.packet + PacketHeaderBytes bsReflect.Len = int(ep.dataCap) bsReflect.Cap = int(ep.dataCap) @@ -76,12 +76,12 @@ var ioSync int64 func raceBecomeActive() { if sync.RaceEnabled { - sync.RaceAcquire((unsafe.Pointer)(&ioSync)) + sync.RaceAcquire(unsafe.Pointer(&ioSync)) } } func raceBecomeInactive() { if sync.RaceEnabled { - sync.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) + sync.RaceReleaseMerge(unsafe.Pointer(&ioSync)) } } diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go index 168c1ccff..0e559ee16 100644 --- a/pkg/flipcall/futex_linux.go +++ b/pkg/flipcall/futex_linux.go @@ -17,7 +17,6 @@ package flipcall import ( - "encoding/json" "fmt" "runtime" "sync/atomic" @@ -26,55 +25,26 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" ) -func (ep *Endpoint) futexConnect(req *ctrlHandshakeRequest) (ctrlHandshakeResponse, error) { - var resp ctrlHandshakeResponse - - // Write the handshake request. - w := ep.NewWriter() - if err := json.NewEncoder(w).Encode(req); err != nil { - return resp, fmt.Errorf("error writing handshake request: %v", err) - } - *ep.dataLen() = w.Len() - - // Exchange control with the server. - if err := ep.futexSwitchToPeer(); err != nil { - return resp, err +func (ep *Endpoint) futexSetPeerActive() error { + if atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) { + return nil } - if err := ep.futexSwitchFromPeer(); err != nil { - return resp, err + switch cs := atomic.LoadUint32(ep.connState()); cs { + case csShutdown: + return ShutdownError{} + default: + return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs) } - - // Read the handshake response. - respLen := atomic.LoadUint32(ep.dataLen()) - if respLen > ep.dataCap { - return resp, fmt.Errorf("invalid handshake response length %d (maximum %d)", respLen, ep.dataCap) - } - if err := json.NewDecoder(ep.NewReader(respLen)).Decode(&resp); err != nil { - return resp, fmt.Errorf("error reading handshake response: %v", err) - } - - return resp, nil } -func (ep *Endpoint) futexSwitchToPeer() error { - // Update connection state to indicate that the peer should be active. - if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) { - switch cs := atomic.LoadUint32(ep.connState()); cs { - case csShutdown: - return ShutdownError{} - default: - return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs) - } - } - - // Wake the peer's Endpoint.futexSwitchFromPeer(). +func (ep *Endpoint) futexWakePeer() error { if err := ep.futexWakeConnState(1); err != nil { return fmt.Errorf("failed to FUTEX_WAKE peer Endpoint: %v", err) } return nil } -func (ep *Endpoint) futexSwitchFromPeer() error { +func (ep *Endpoint) futexWaitUntilActive() error { for { switch cs := atomic.LoadUint32(ep.connState()); cs { case ep.activeState: diff --git a/pkg/goid/BUILD b/pkg/goid/BUILD index 7a82631c5..d855b702c 100644 --- a/pkg/goid/BUILD +++ b/pkg/goid/BUILD @@ -8,8 +8,6 @@ go_library( "goid.go", "goid_amd64.s", "goid_arm64.s", - "goid_race.go", - "goid_unsafe.go", ], visibility = ["//visibility:public"], ) @@ -18,7 +16,6 @@ go_test( name = "goid_test", size = "small", srcs = [ - "empty_test.go", "goid_test.go", ], library = ":goid", diff --git a/pkg/goid/empty_test.go b/pkg/goid/empty_test.go deleted file mode 100644 index c0a4b17ab..000000000 --- a/pkg/goid/empty_test.go +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build !race - -package goid - -import "testing" - -// TestNothing exists to make the build system happy. -func TestNothing(t *testing.T) {} diff --git a/pkg/goid/goid.go b/pkg/goid/goid.go index 39df30031..17c384cb0 100644 --- a/pkg/goid/goid.go +++ b/pkg/goid/goid.go @@ -12,13 +12,61 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build !race +// +build go1.12 +// +build !go1.17 -// Package goid provides access to the ID of the current goroutine in -// race/gotsan builds. +// Check type signatures when updating Go version. + +// Package goid provides the Get function. package goid // Get returns the ID of the current goroutine. func Get() int64 { - panic("unimplemented for non-race builds") + return getg().goid +} + +// Structs from Go runtime. These may change in the future and require +// updating. These structs are currently the same on both AMD64 and ARM64, +// but may diverge in the future. + +type stack struct { + lo uintptr + hi uintptr +} + +type gobuf struct { + sp uintptr + pc uintptr + g uintptr + ctxt uintptr + ret uint64 + lr uintptr + bp uintptr } + +type g struct { + stack stack + stackguard0 uintptr + stackguard1 uintptr + + _panic uintptr + _defer uintptr + m uintptr + sched gobuf + syscallsp uintptr + syscallpc uintptr + stktopsp uintptr + param uintptr + atomicstatus uint32 + stackLock uint32 + goid int64 + + // More fields... + // + // We only use goid and the fields before it are only listed to + // calculate the correct offset. +} + +// Defined in assembly. This can't use go:linkname since runtime.getg() isn't a +// real function, it's a compiler intrinsic. +func getg() *g diff --git a/pkg/goid/goid_race.go b/pkg/goid/goid_race.go deleted file mode 100644 index 1766beaee..000000000 --- a/pkg/goid/goid_race.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Only available in race/gotsan builds. -// +build race - -// Package goid provides access to the ID of the current goroutine in -// race/gotsan builds. -package goid - -// Get returns the ID of the current goroutine. -func Get() int64 { - return goid() -} diff --git a/pkg/goid/goid_test.go b/pkg/goid/goid_test.go index 31970ce79..54be11d63 100644 --- a/pkg/goid/goid_test.go +++ b/pkg/goid/goid_test.go @@ -12,63 +12,70 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build race - package goid import ( "runtime" "sync" "testing" + "time" ) -func TestInitialGoID(t *testing.T) { - const max = 10000 - if id := goid(); id < 0 || id > max { - t.Errorf("got goid = %d, want 0 < goid <= %d", id, max) - } -} +func TestUniquenessAndConsistency(t *testing.T) { + const ( + numGoroutines = 5000 -// TestGoIDSquence verifies that goid returns values which could plausibly be -// goroutine IDs. If this test breaks or becomes flaky, the structs in -// goid_unsafe.go may need to be updated. -func TestGoIDSquence(t *testing.T) { - // Goroutine IDs are cached by each P. - runtime.GOMAXPROCS(1) + // maxID is not an intrinsic property of goroutine IDs; it is only a + // property of how the Go runtime currently assigns them. Future + // changes to the Go runtime may require that maxID be raised, or that + // assertions regarding it be removed entirely. + maxID = numGoroutines + 1000 + ) - // Fill any holes in lower range. - for i := 0; i < 50; i++ { - var wg sync.WaitGroup - wg.Add(1) + var ( + goidsMu sync.Mutex + goids = make(map[int64]struct{}) + checkedWG sync.WaitGroup + exitCh = make(chan struct{}) + ) + for i := 0; i < numGoroutines; i++ { + checkedWG.Add(1) go func() { - wg.Done() - - // Leak the goroutine to prevent the ID from being - // reused. - select {} - }() - wg.Wait() - } - - id := goid() - for i := 0; i < 100; i++ { - var ( - newID int64 - wg sync.WaitGroup - ) - wg.Add(1) - go func() { - newID = goid() - wg.Done() - - // Leak the goroutine to prevent the ID from being - // reused. - select {} + id := Get() + if id > maxID { + t.Errorf("observed unexpectedly large goroutine ID %d", id) + } + goidsMu.Lock() + if _, dup := goids[id]; dup { + t.Errorf("observed duplicate goroutine ID %d", id) + } + goids[id] = struct{}{} + goidsMu.Unlock() + checkedWG.Done() + for { + if curID := Get(); curID != id { + t.Errorf("goroutine ID changed from %d to %d", id, curID) + // Don't spam logs by repeating the check; wait quietly for + // the test to finish. + <-exitCh + return + } + // Check if the test is over. + select { + case <-exitCh: + return + default: + } + // Yield to other goroutines, and possibly migrate to another P. + runtime.Gosched() + } }() - wg.Wait() - if max := id + 100; newID <= id || newID > max { - t.Errorf("unexpected goroutine ID pattern, got goid = %d, want %d < goid <= %d (previous = %d)", newID, id, max, id) - } - id = newID } + // Wait for all goroutines to perform uniqueness checks. + checkedWG.Wait() + // Wait for an additional second to allow goroutines to spin checking for + // ID consistency. + time.Sleep(time.Second) + // Request that all goroutines exit. + close(exitCh) } diff --git a/pkg/goid/goid_unsafe.go b/pkg/goid/goid_unsafe.go deleted file mode 100644 index ded8004dd..000000000 --- a/pkg/goid/goid_unsafe.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package goid - -// Structs from Go runtime. These may change in the future and require -// updating. These structs are currently the same on both AMD64 and ARM64, -// but may diverge in the future. - -type stack struct { - lo uintptr - hi uintptr -} - -type gobuf struct { - sp uintptr - pc uintptr - g uintptr - ctxt uintptr - ret uint64 - lr uintptr - bp uintptr -} - -type g struct { - stack stack - stackguard0 uintptr - stackguard1 uintptr - - _panic uintptr - _defer uintptr - m uintptr - sched gobuf - syscallsp uintptr - syscallpc uintptr - stktopsp uintptr - param uintptr - atomicstatus uint32 - stackLock uint32 - goid int64 - - // More fields... - // - // We only use goid and the fields before it are only listed to - // calculate the correct offset. -} - -func getg() *g - -// goid returns the ID of the current goroutine. -func goid() int64 { - return getg().goid -} diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index e0a9e56c5..6acee90ef 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -19,6 +19,7 @@ import ( "bytes" "crypto/sha256" "crypto/sha512" + "encoding/gob" "fmt" "io" @@ -151,11 +152,15 @@ type VerityDescriptor struct { Mode uint32 UID uint32 GID uint32 + Children map[string]struct{} RootHash []byte } func (d *VerityDescriptor) String() string { - return fmt.Sprintf("Name: %s, Size: %d, Mode: %d, UID: %d, GID: %d, RootHash: %v", d.Name, d.FileSize, d.Mode, d.UID, d.GID, d.RootHash) + b := new(bytes.Buffer) + e := gob.NewEncoder(b) + e.Encode(d.Children) + return fmt.Sprintf("Name: %s, Size: %d, Mode: %d, UID: %d, GID: %d, Children: %v, RootHash: %v", d.Name, d.FileSize, d.Mode, d.UID, d.GID, b.Bytes(), d.RootHash) } // verify generates a hash from d, and compares it with expected. @@ -202,6 +207,9 @@ type GenerateParams struct { UID uint32 // GID is the group ID of the target file. GID uint32 + // Children is a map of children names for a directory. It should be + // empty for a regular file. + Children map[string]struct{} // HashAlgorithms is the algorithms used to hash data. HashAlgorithms int // TreeReader is a reader for the Merkle tree. @@ -294,6 +302,7 @@ func Generate(params *GenerateParams) ([]byte, error) { Mode: params.Mode, UID: params.UID, GID: params.GID, + Children: params.Children, RootHash: root, } return hashData([]byte(descriptor.String()), params.HashAlgorithms) @@ -318,6 +327,9 @@ type VerifyParams struct { UID uint32 // GID is the group ID of the target file. GID uint32 + // Children is a map of children names for a directory. It should be + // empty for a regular file. + Children map[string]struct{} // HashAlgorithms is the algorithms used to hash data. HashAlgorithms int // ReadOffset is the offset of the data range to be verified. @@ -348,6 +360,7 @@ func verifyMetadata(params *VerifyParams, layout *Layout) error { Mode: params.Mode, UID: params.UID, GID: params.GID, + Children: params.Children, RootHash: root, } return descriptor.verify(params.Expected, params.HashAlgorithms) @@ -409,6 +422,7 @@ func Verify(params *VerifyParams) (int64, error) { Mode: params.Mode, UID: params.UID, GID: params.GID, + Children: params.Children, } if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.HashAlgorithms, params.Expected); err != nil { return 0, err diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go index 405204d94..66ddf09e6 100644 --- a/pkg/merkletree/merkletree_test.go +++ b/pkg/merkletree/merkletree_test.go @@ -16,6 +16,7 @@ package merkletree import ( "bytes" + "errors" "fmt" "io" "math/rand" @@ -28,6 +29,7 @@ import ( func TestLayout(t *testing.T) { testCases := []struct { + name string dataSize int64 hashAlgorithms int dataAndTreeInSameFile bool @@ -35,6 +37,7 @@ func TestLayout(t *testing.T) { expectedLevelOffset []int64 }{ { + name: "SmallSizeSHA256SeparateFile", dataSize: 100, hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, @@ -42,6 +45,7 @@ func TestLayout(t *testing.T) { expectedLevelOffset: []int64{0}, }, { + name: "SmallSizeSHA512SeparateFile", dataSize: 100, hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: false, @@ -49,6 +53,7 @@ func TestLayout(t *testing.T) { expectedLevelOffset: []int64{0}, }, { + name: "SmallSizeSHA256SameFile", dataSize: 100, hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, @@ -56,6 +61,7 @@ func TestLayout(t *testing.T) { expectedLevelOffset: []int64{usermem.PageSize}, }, { + name: "SmallSizeSHA512SameFile", dataSize: 100, hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: true, @@ -63,6 +69,7 @@ func TestLayout(t *testing.T) { expectedLevelOffset: []int64{usermem.PageSize}, }, { + name: "MiddleSizeSHA256SeparateFile", dataSize: 1000000, hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, @@ -70,6 +77,7 @@ func TestLayout(t *testing.T) { expectedLevelOffset: []int64{0, 2 * usermem.PageSize, 3 * usermem.PageSize}, }, { + name: "MiddleSizeSHA512SeparateFile", dataSize: 1000000, hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: false, @@ -77,6 +85,7 @@ func TestLayout(t *testing.T) { expectedLevelOffset: []int64{0, 4 * usermem.PageSize, 5 * usermem.PageSize}, }, { + name: "MiddleSizeSHA256SameFile", dataSize: 1000000, hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, @@ -84,6 +93,7 @@ func TestLayout(t *testing.T) { expectedLevelOffset: []int64{245 * usermem.PageSize, 247 * usermem.PageSize, 248 * usermem.PageSize}, }, { + name: "MiddleSizeSHA512SameFile", dataSize: 1000000, hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: true, @@ -91,6 +101,7 @@ func TestLayout(t *testing.T) { expectedLevelOffset: []int64{245 * usermem.PageSize, 249 * usermem.PageSize, 250 * usermem.PageSize}, }, { + name: "LargeSizeSHA256SeparateFile", dataSize: 4096 * int64(usermem.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, @@ -98,6 +109,7 @@ func TestLayout(t *testing.T) { expectedLevelOffset: []int64{0, 32 * usermem.PageSize, 33 * usermem.PageSize}, }, { + name: "LargeSizeSHA512SeparateFile", dataSize: 4096 * int64(usermem.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: false, @@ -105,6 +117,7 @@ func TestLayout(t *testing.T) { expectedLevelOffset: []int64{0, 64 * usermem.PageSize, 65 * usermem.PageSize}, }, { + name: "LargeSizeSHA256SameFile", dataSize: 4096 * int64(usermem.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, @@ -112,6 +125,7 @@ func TestLayout(t *testing.T) { expectedLevelOffset: []int64{4096 * usermem.PageSize, 4128 * usermem.PageSize, 4129 * usermem.PageSize}, }, { + name: "LargeSizeSHA512SameFile", dataSize: 4096 * int64(usermem.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: true, @@ -121,7 +135,7 @@ func TestLayout(t *testing.T) { } for _, tc := range testCases { - t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { l, err := InitLayout(tc.dataSize, tc.hashAlgorithms, tc.dataAndTreeInSameFile) if err != nil { t.Fatalf("Failed to InitLayout: %v", err) @@ -178,418 +192,883 @@ func TestGenerate(t *testing.T) { // The input data has size dataSize. It starts with the data in startWith, // and all other bytes are zeroes. testCases := []struct { - data []byte - hashAlgorithms int - expectedHash []byte + name string + data []byte + hashAlgorithms int + dataAndTreeInSameFile bool + expectedHash []byte }{ { - data: bytes.Repeat([]byte{0}, usermem.PageSize), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - expectedHash: []byte{39, 30, 12, 152, 185, 58, 32, 84, 218, 79, 74, 113, 104, 219, 230, 234, 25, 126, 147, 36, 212, 44, 76, 74, 25, 93, 228, 41, 243, 143, 59, 147}, + name: "OnePageZeroesSHA256SeparateFile", + data: bytes.Repeat([]byte{0}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + expectedHash: []byte{42, 197, 191, 52, 206, 122, 93, 34, 198, 125, 100, 154, 171, 177, 94, 14, 49, 40, 76, 157, 122, 58, 78, 6, 163, 248, 30, 238, 16, 190, 173, 175}, + }, + { + name: "OnePageZeroesSHA256SameFile", + data: bytes.Repeat([]byte{0}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + expectedHash: []byte{42, 197, 191, 52, 206, 122, 93, 34, 198, 125, 100, 154, 171, 177, 94, 14, 49, 40, 76, 157, 122, 58, 78, 6, 163, 248, 30, 238, 16, 190, 173, 175}, + }, + { + name: "OnePageZeroesSHA512SeparateFile", + data: bytes.Repeat([]byte{0}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedHash: []byte{87, 131, 150, 74, 0, 218, 117, 114, 34, 23, 212, 16, 122, 97, 124, 172, 41, 46, 107, 150, 33, 46, 56, 39, 5, 246, 215, 187, 140, 83, 35, 63, 111, 74, 155, 241, 161, 214, 92, 141, 232, 125, 99, 71, 168, 102, 82, 20, 229, 249, 248, 28, 29, 238, 199, 223, 173, 180, 179, 46, 241, 240, 237, 74}, + }, + { + name: "OnePageZeroesSHA512SameFile", + data: bytes.Repeat([]byte{0}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedHash: []byte{87, 131, 150, 74, 0, 218, 117, 114, 34, 23, 212, 16, 122, 97, 124, 172, 41, 46, 107, 150, 33, 46, 56, 39, 5, 246, 215, 187, 140, 83, 35, 63, 111, 74, 155, 241, 161, 214, 92, 141, 232, 125, 99, 71, 168, 102, 82, 20, 229, 249, 248, 28, 29, 238, 199, 223, 173, 180, 179, 46, 241, 240, 237, 74}, + }, + { + name: "MultiplePageZeroesSHA256SeparateFile", + data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + expectedHash: []byte{115, 151, 35, 147, 223, 91, 17, 6, 162, 145, 237, 81, 88, 53, 120, 49, 128, 70, 188, 28, 254, 241, 19, 233, 30, 243, 71, 225, 57, 58, 61, 38}, + }, + { + name: "MultiplePageZeroesSHA256SameFile", + data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + expectedHash: []byte{115, 151, 35, 147, 223, 91, 17, 6, 162, 145, 237, 81, 88, 53, 120, 49, 128, 70, 188, 28, 254, 241, 19, 233, 30, 243, 71, 225, 57, 58, 61, 38}, + }, + { + name: "MultiplePageZeroesSHA512SeparateFile", + data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedHash: []byte{41, 94, 205, 97, 254, 226, 171, 69, 76, 102, 197, 47, 113, 53, 24, 244, 103, 131, 83, 73, 87, 212, 247, 140, 32, 144, 211, 158, 25, 131, 194, 57, 21, 224, 128, 119, 69, 100, 45, 50, 157, 54, 46, 214, 152, 179, 59, 78, 28, 48, 146, 160, 204, 48, 27, 90, 152, 193, 167, 45, 150, 67, 66, 217}, }, { - data: bytes.Repeat([]byte{0}, usermem.PageSize), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - expectedHash: []byte{184, 76, 172, 204, 17, 136, 127, 75, 224, 42, 251, 181, 98, 149, 1, 44, 58, 148, 20, 187, 30, 174, 73, 87, 166, 9, 109, 169, 42, 96, 87, 202, 59, 82, 174, 80, 51, 95, 101, 100, 6, 246, 56, 120, 27, 166, 29, 59, 67, 115, 227, 121, 241, 177, 63, 238, 82, 157, 43, 107, 174, 180, 44, 84}, + name: "MultiplePageZeroesSHA512SameFile", + data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedHash: []byte{41, 94, 205, 97, 254, 226, 171, 69, 76, 102, 197, 47, 113, 53, 24, 244, 103, 131, 83, 73, 87, 212, 247, 140, 32, 144, 211, 158, 25, 131, 194, 57, 21, 224, 128, 119, 69, 100, 45, 50, 157, 54, 46, 214, 152, 179, 59, 78, 28, 48, 146, 160, 204, 48, 27, 90, 152, 193, 167, 45, 150, 67, 66, 217}, }, { - data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - expectedHash: []byte{213, 221, 252, 9, 241, 250, 186, 1, 242, 132, 83, 77, 180, 207, 119, 48, 206, 113, 37, 253, 252, 159, 71, 70, 3, 53, 42, 244, 230, 244, 173, 143}, + name: "SingleASHA256SeparateFile", + data: []byte{'a'}, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + expectedHash: []byte{52, 159, 140, 206, 140, 138, 231, 140, 94, 14, 252, 66, 175, 128, 191, 14, 52, 215, 190, 184, 165, 50, 182, 224, 42, 156, 145, 0, 1, 15, 187, 85}, }, { - data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - expectedHash: []byte{40, 231, 187, 28, 3, 171, 168, 36, 177, 244, 118, 131, 218, 226, 106, 55, 245, 157, 244, 147, 144, 57, 41, 182, 65, 6, 13, 49, 38, 66, 237, 117, 124, 110, 250, 246, 248, 132, 201, 156, 195, 201, 142, 179, 122, 128, 195, 194, 187, 240, 129, 171, 168, 182, 101, 58, 194, 155, 99, 147, 49, 130, 161, 178}, + name: "SingleASHA256SameFile", + data: []byte{'a'}, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + expectedHash: []byte{52, 159, 140, 206, 140, 138, 231, 140, 94, 14, 252, 66, 175, 128, 191, 14, 52, 215, 190, 184, 165, 50, 182, 224, 42, 156, 145, 0, 1, 15, 187, 85}, }, { - data: []byte{'a'}, - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - expectedHash: []byte{182, 25, 170, 240, 16, 153, 234, 4, 101, 238, 197, 154, 182, 168, 171, 96, 177, 33, 171, 117, 73, 78, 124, 239, 82, 255, 215, 121, 156, 95, 121, 171}, + name: "SingleASHA512SeparateFile", + data: []byte{'a'}, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedHash: []byte{232, 90, 223, 95, 60, 151, 149, 172, 174, 58, 206, 97, 189, 103, 6, 202, 67, 248, 1, 189, 243, 51, 250, 42, 5, 89, 195, 9, 50, 74, 39, 169, 114, 228, 109, 225, 128, 210, 63, 94, 18, 133, 58, 48, 225, 100, 176, 55, 87, 60, 235, 224, 143, 41, 15, 253, 94, 28, 251, 233, 99, 207, 152, 108}, }, { - data: []byte{'a'}, - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - expectedHash: []byte{121, 28, 140, 244, 32, 222, 61, 255, 184, 65, 117, 84, 132, 197, 122, 214, 95, 249, 164, 77, 211, 192, 217, 59, 109, 255, 249, 253, 27, 142, 110, 29, 93, 153, 92, 211, 178, 198, 136, 34, 61, 157, 141, 94, 145, 191, 201, 134, 141, 138, 51, 26, 33, 187, 17, 196, 113, 234, 125, 219, 4, 41, 57, 120}, + name: "SingleASHA512SameFile", + data: []byte{'a'}, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedHash: []byte{232, 90, 223, 95, 60, 151, 149, 172, 174, 58, 206, 97, 189, 103, 6, 202, 67, 248, 1, 189, 243, 51, 250, 42, 5, 89, 195, 9, 50, 74, 39, 169, 114, 228, 109, 225, 128, 210, 63, 94, 18, 133, 58, 48, 225, 100, 176, 55, 87, 60, 235, 224, 143, 41, 15, 253, 94, 28, 251, 233, 99, 207, 152, 108}, }, { - data: bytes.Repeat([]byte{'a'}, usermem.PageSize), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - expectedHash: []byte{17, 40, 99, 150, 206, 124, 196, 184, 41, 40, 50, 91, 113, 47, 8, 204, 2, 102, 202, 86, 157, 92, 218, 53, 151, 250, 234, 247, 191, 121, 113, 246}, + name: "OnePageASHA256SeparateFile", + data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + expectedHash: []byte{157, 60, 139, 54, 248, 39, 187, 77, 31, 107, 241, 26, 240, 49, 83, 159, 182, 60, 128, 85, 121, 204, 15, 249, 44, 248, 127, 134, 58, 220, 41, 185}, + }, + { + name: "OnePageASHA256SameFile", + data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + expectedHash: []byte{157, 60, 139, 54, 248, 39, 187, 77, 31, 107, 241, 26, 240, 49, 83, 159, 182, 60, 128, 85, 121, 204, 15, 249, 44, 248, 127, 134, 58, 220, 41, 185}, }, { - data: bytes.Repeat([]byte{'a'}, usermem.PageSize), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - expectedHash: []byte{100, 22, 249, 78, 47, 163, 220, 231, 228, 165, 226, 192, 221, 77, 106, 69, 115, 104, 208, 155, 124, 206, 225, 233, 98, 249, 232, 225, 114, 119, 110, 216, 117, 106, 85, 7, 200, 206, 139, 81, 116, 37, 215, 158, 89, 110, 74, 86, 66, 95, 117, 237, 70, 56, 62, 175, 48, 147, 162, 122, 253, 57, 123, 84}, + name: "OnePageASHA512SeparateFile", + data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedHash: []byte{116, 22, 252, 100, 32, 241, 254, 228, 167, 228, 110, 146, 156, 189, 6, 30, 27, 127, 94, 181, 15, 98, 173, 60, 34, 102, 92, 174, 181, 80, 205, 90, 88, 12, 125, 194, 148, 175, 184, 168, 37, 66, 127, 194, 19, 132, 93, 147, 168, 217, 227, 131, 100, 25, 213, 255, 132, 60, 196, 217, 24, 158, 1, 50}, + }, + { + name: "OnePageASHA512SameFile", + data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedHash: []byte{116, 22, 252, 100, 32, 241, 254, 228, 167, 228, 110, 146, 156, 189, 6, 30, 27, 127, 94, 181, 15, 98, 173, 60, 34, 102, 92, 174, 181, 80, 205, 90, 88, 12, 125, 194, 148, 175, 184, 168, 37, 66, 127, 194, 19, 132, 93, 147, 168, 217, 227, 131, 100, 25, 213, 255, 132, 60, 196, 217, 24, 158, 1, 50}, }, } for _, tc := range testCases { - t.Run(fmt.Sprintf("%d:%v", len(tc.data), tc.data[0]), func(t *testing.T) { - for _, dataAndTreeInSameFile := range []bool{false, true} { - var tree bytesReadWriter - params := GenerateParams{ - Size: int64(len(tc.data)), - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - HashAlgorithms: tc.hashAlgorithms, - TreeReader: &tree, - TreeWriter: &tree, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } - if dataAndTreeInSameFile { - tree.Write(tc.data) - params.File = &tree - } else { - params.File = &bytesReadWriter{ - bytes: tc.data, - } - } - hash, err := Generate(¶ms) - if err != nil { - t.Fatalf("Got err: %v, want nil", err) + t.Run(fmt.Sprintf(tc.name), func(t *testing.T) { + var tree bytesReadWriter + params := GenerateParams{ + Size: int64(len(tc.data)), + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + Children: make(map[string]struct{}), + HashAlgorithms: tc.hashAlgorithms, + TreeReader: &tree, + TreeWriter: &tree, + DataAndTreeInSameFile: tc.dataAndTreeInSameFile, + } + if tc.dataAndTreeInSameFile { + tree.Write(tc.data) + params.File = &tree + } else { + params.File = &bytesReadWriter{ + bytes: tc.data, } + } + hash, err := Generate(¶ms) + if err != nil { + t.Fatalf("Got err: %v, want nil", err) + } + if !bytes.Equal(hash, tc.expectedHash) { + t.Errorf("Got hash: %v, want %v", hash, tc.expectedHash) + } + }) + } +} - if !bytes.Equal(hash, tc.expectedHash) { - t.Errorf("Got hash: %v, want %v", hash, tc.expectedHash) - } +// prepareVerify generates test data and corresponding Merkle tree, and returns +// the prepared VerifyParams. +// The test data has size dataSize. The data is hashed with hashAlgorithms. The +// portion to be verified ranges from verifyStart with verifySize. +func prepareVerify(t *testing.T, dataSize int64, hashAlgorithm int, dataAndTreeInSameFile bool, verifyStart int64, verifySize int64, out io.Writer) ([]byte, VerifyParams) { + t.Helper() + data := make([]byte, dataSize) + // Generate random bytes in data. + rand.Read(data) + + var tree bytesReadWriter + genParams := GenerateParams{ + Size: int64(len(data)), + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + Children: make(map[string]struct{}), + HashAlgorithms: hashAlgorithm, + TreeReader: &tree, + TreeWriter: &tree, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } + if dataAndTreeInSameFile { + tree.Write(data) + genParams.File = &tree + } else { + genParams.File = &bytesReadWriter{ + bytes: data, + } + } + hash, err := Generate(&genParams) + if err != nil { + t.Fatalf("could not generate Merkle tree:%v", err) + } + + return data, VerifyParams{ + Out: out, + File: bytes.NewReader(data), + Tree: &tree, + Size: dataSize, + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + Children: make(map[string]struct{}), + HashAlgorithms: hashAlgorithm, + ReadOffset: verifyStart, + ReadSize: verifySize, + Expected: hash, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } +} + +func TestVerifyInvalidRange(t *testing.T) { + testCases := []struct { + name string + verifyStart int64 + verifySize int64 + }{ + // Verify range starts outside data range. + { + name: "StartOutsideRange", + verifyStart: usermem.PageSize, + verifySize: 1, + }, + // Verify range ends outside data range. + { + name: "EndOutsideRange", + verifyStart: 0, + verifySize: 2 * usermem.PageSize, + }, + // Verify range with negative size. + { + name: "NegativeSize", + verifyStart: 1, + verifySize: -1, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + _, params := prepareVerify(t, usermem.PageSize /* dataSize */, linux.FS_VERITY_HASH_ALG_SHA256, false /* dataAndTreeInSameFile */, tc.verifyStart, tc.verifySize, &buf) + if _, err := Verify(¶ms); errors.Is(err, nil) { + t.Errorf("Verification succeeded when expected to fail") + } + }) + } +} + +func TestVerifyUnmodifiedMetadata(t *testing.T) { + testCases := []struct { + name string + hashAlgorithm int + dataAndTreeInSameFile bool + }{ + { + name: "SHA256SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "SHA512SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "SHA256SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "SHA512SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + _, params := prepareVerify(t, usermem.PageSize /* dataSize */, tc.hashAlgorithm, tc.dataAndTreeInSameFile, 0 /* verifyStart */, 0 /* verifySize */, &buf) + if _, err := Verify(¶ms); !errors.Is(err, nil) { + t.Errorf("Verification failed when expected to succeed: %v", err) + } + }) + } +} + +func TestVerifyModifiedName(t *testing.T) { + testCases := []struct { + name string + hashAlgorithm int + dataAndTreeInSameFile bool + }{ + { + name: "SHA256SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "SHA512SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "SHA256SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "SHA512SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + _, params := prepareVerify(t, usermem.PageSize /* dataSize */, tc.hashAlgorithm, tc.dataAndTreeInSameFile, 0 /* verifyStart */, 0 /* verifySize */, &buf) + params.Name += "abc" + if _, err := Verify(¶ms); errors.Is(err, nil) { + t.Errorf("Verification succeeded when expected to fail") + } + }) + } +} + +func TestVerifyModifiedSize(t *testing.T) { + testCases := []struct { + name string + hashAlgorithm int + dataAndTreeInSameFile bool + }{ + { + name: "SHA256SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "SHA512SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "SHA256SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "SHA512SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + _, params := prepareVerify(t, usermem.PageSize /* dataSize */, tc.hashAlgorithm, tc.dataAndTreeInSameFile, 0 /* verifyStart */, 0 /* verifySize */, &buf) + params.Size-- + if _, err := Verify(¶ms); errors.Is(err, nil) { + t.Errorf("Verification succeeded when expected to fail") + } + }) + } +} + +func TestVerifyModifiedMode(t *testing.T) { + testCases := []struct { + name string + hashAlgorithm int + dataAndTreeInSameFile bool + }{ + { + name: "SHA256SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "SHA512SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "SHA256SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "SHA512SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + _, params := prepareVerify(t, usermem.PageSize /* dataSize */, tc.hashAlgorithm, tc.dataAndTreeInSameFile, 0 /* verifyStart */, 0 /* verifySize */, &buf) + params.Mode++ + if _, err := Verify(¶ms); errors.Is(err, nil) { + t.Errorf("Verification succeeded when expected to fail") + } + }) + } +} + +func TestVerifyModifiedUID(t *testing.T) { + testCases := []struct { + name string + hashAlgorithm int + dataAndTreeInSameFile bool + }{ + { + name: "SHA256SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "SHA512SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "SHA256SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "SHA512SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + _, params := prepareVerify(t, usermem.PageSize /* dataSize */, tc.hashAlgorithm, tc.dataAndTreeInSameFile, 0 /* verifyStart */, 0 /* verifySize */, &buf) + params.UID++ + if _, err := Verify(¶ms); errors.Is(err, nil) { + t.Errorf("Verification succeeded when expected to fail") + } + }) + } +} + +func TestVerifyModifiedGID(t *testing.T) { + testCases := []struct { + name string + hashAlgorithm int + dataAndTreeInSameFile bool + }{ + { + name: "SHA256SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "SHA512SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "SHA256SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "SHA512SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + _, params := prepareVerify(t, usermem.PageSize /* dataSize */, tc.hashAlgorithm, tc.dataAndTreeInSameFile, 0 /* verifyStart */, 0 /* verifySize */, &buf) + params.GID++ + if _, err := Verify(¶ms); errors.Is(err, nil) { + t.Errorf("Verification succeeded when expected to fail") + } + }) + } +} + +func TestVerifyModifiedChildren(t *testing.T) { + testCases := []struct { + name string + hashAlgorithm int + dataAndTreeInSameFile bool + }{ + { + name: "SHA256SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "SHA512SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "SHA256SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "SHA512SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + _, params := prepareVerify(t, usermem.PageSize /* dataSize */, tc.hashAlgorithm, tc.dataAndTreeInSameFile, 0 /* verifyStart */, 0 /* verifySize */, &buf) + params.Children["abc"] = struct{}{} + if _, err := Verify(¶ms); errors.Is(err, nil) { + t.Errorf("Verification succeeded when expected to fail") + } + }) + } +} + +func TestModifyOutsideVerifyRange(t *testing.T) { + testCases := []struct { + name string + // The byte with index modifyByte is modified. + modifyByte int64 + hashAlgorithm int + dataAndTreeInSameFile bool + }{ + { + name: "BeforeRangeSHA256SeparateFile", + modifyByte: 4*usermem.PageSize - 1, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "BeforeRangeSHA512SeparateFile", + modifyByte: 4*usermem.PageSize - 1, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "BeforeRangeSHA256SameFile", + modifyByte: 4*usermem.PageSize - 1, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "BeforeRangeSHA512SameFile", + modifyByte: 4*usermem.PageSize - 1, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + { + name: "AfterRangeSHA256SeparateFile", + modifyByte: 5 * usermem.PageSize, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "AfterRangeSHA512SeparateFile", + modifyByte: 5 * usermem.PageSize, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "AfterRangeSHA256SameFile", + modifyByte: 5 * usermem.PageSize, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "AfterRangeSHA256SameFile", + modifyByte: 5 * usermem.PageSize, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dataSize := int64(8 * usermem.PageSize) + verifyStart := int64(4 * usermem.PageSize) + verifySize := int64(usermem.PageSize) + var buf bytes.Buffer + // Modified byte is outside verify range. Verify should succeed. + data, params := prepareVerify(t, dataSize, tc.hashAlgorithm, tc.dataAndTreeInSameFile, verifyStart, verifySize, &buf) + // Flip a bit in data and checks Verify results. + data[tc.modifyByte] ^= 1 + n, err := Verify(¶ms) + if !errors.Is(err, nil) { + t.Errorf("Verification failed when expected to succeed: %v", err) + } + if n != verifySize { + t.Errorf("Got Verify output size %d, want %d", n, verifySize) + } + if int64(buf.Len()) != verifySize { + t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), verifySize) + } + if !bytes.Equal(data[verifyStart:verifyStart+verifySize], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") } }) } } -func TestVerify(t *testing.T) { - // The input data has size dataSize. The portion to be verified ranges from - // verifyStart with verifySize. A bit is flipped in outOfRangeByteIndex to - // confirm that modifications outside the verification range does not cause - // issue. And a bit is flipped in modifyByte to confirm that - // modifications in the verification range is caught during verification. +func TestModifyInsideVerifyRange(t *testing.T) { testCases := []struct { - dataSize int64 + name string verifyStart int64 verifySize int64 - // A byte in input data is modified during the test. If the - // modified byte falls in verification range, Verify should - // fail, otherwise Verify should still succeed. - modifyByte int64 - modifyName bool - modifySize bool - modifyMode bool - modifyUID bool - modifyGID bool - shouldSucceed bool + // The byte with index modifyByte is modified. + modifyByte int64 + hashAlgorithm int + dataAndTreeInSameFile bool }{ - // Verify range start outside the data range should fail. - { - dataSize: usermem.PageSize, - verifyStart: usermem.PageSize, - verifySize: 1, - modifyByte: 0, - shouldSucceed: false, - }, - // Verifying range is valid if it starts inside data and ends - // outside data range, in that case start to the end of data is - // verified. - { - dataSize: usermem.PageSize, - verifyStart: 0, - verifySize: 2 * usermem.PageSize, - modifyByte: 0, - shouldSucceed: false, - }, - // Invalid verify range (negative size) should fail. - { - dataSize: usermem.PageSize, - verifyStart: 1, - verifySize: -1, - modifyByte: 0, - shouldSucceed: false, - }, - // 0 verify size should only verify metadata. - { - dataSize: usermem.PageSize, - verifyStart: 0, - verifySize: 0, - modifyByte: 0, - shouldSucceed: true, - }, - // Modified name should fail verification. - { - dataSize: usermem.PageSize, - verifyStart: 0, - verifySize: 0, - modifyByte: 0, - modifyName: true, - shouldSucceed: false, - }, - // Modified size should fail verification. - { - dataSize: usermem.PageSize, - verifyStart: 0, - verifySize: 0, - modifyByte: 0, - modifySize: true, - shouldSucceed: false, - }, - // Modified mode should fail verification. - { - dataSize: usermem.PageSize, - verifyStart: 0, - verifySize: 0, - modifyByte: 0, - modifyMode: true, - shouldSucceed: false, - }, - // Modified UID should fail verification. - { - dataSize: usermem.PageSize, - verifyStart: 0, - verifySize: 0, - modifyByte: 0, - modifyUID: true, - shouldSucceed: false, - }, - // Modified GID should fail verification. - { - dataSize: usermem.PageSize, - verifyStart: 0, - verifySize: 0, - modifyByte: 0, - modifyGID: true, - shouldSucceed: false, - }, - // The test cases below use a block-aligned verify range. + // Test a block-aligned verify range. // Modifying a byte in the verified range should cause verify // to fail. { - dataSize: 8 * usermem.PageSize, - verifyStart: 4 * usermem.PageSize, - verifySize: usermem.PageSize, - modifyByte: 4 * usermem.PageSize, - shouldSucceed: false, + name: "BlockAlignedRangeSHA256SeparateFile", + verifyStart: 4 * usermem.PageSize, + verifySize: usermem.PageSize, + modifyByte: 4 * usermem.PageSize, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, }, - // Modifying a byte before the verified range should not cause - // verify to fail. { - dataSize: 8 * usermem.PageSize, - verifyStart: 4 * usermem.PageSize, - verifySize: usermem.PageSize, - modifyByte: 4*usermem.PageSize - 1, - shouldSucceed: true, + name: "BlockAlignedRangeSHA512SeparateFile", + verifyStart: 4 * usermem.PageSize, + verifySize: usermem.PageSize, + modifyByte: 4 * usermem.PageSize, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "BlockAlignedRangeSHA256SameFile", + verifyStart: 4 * usermem.PageSize, + verifySize: usermem.PageSize, + modifyByte: 4 * usermem.PageSize, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, }, - // Modifying a byte after the verified range should not cause - // verify to fail. { - dataSize: 8 * usermem.PageSize, - verifyStart: 4 * usermem.PageSize, - verifySize: usermem.PageSize, - modifyByte: 5 * usermem.PageSize, - shouldSucceed: true, + name: "BlockAlignedRangeSHA512SameFile", + verifyStart: 4 * usermem.PageSize, + verifySize: usermem.PageSize, + modifyByte: 4 * usermem.PageSize, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, }, // The tests below use a non-block-aligned verify range. // Modifying a byte at strat of verify range should cause // verify to fail. { - dataSize: 8 * usermem.PageSize, - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 4*usermem.PageSize + 123, - shouldSucceed: false, + name: "ModifyStartSHA256SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 4*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyStartSHA512SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 4*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyStartSHA256SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 4*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "ModifyStartSHA512SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 4*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, }, // Modifying a byte at the end of verify range should cause // verify to fail. { - dataSize: 8 * usermem.PageSize, - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 6*usermem.PageSize + 123, - shouldSucceed: false, + name: "ModifyEndSHA256SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 6*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyEndSHA512SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 6*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyEndSHA256SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 6*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "ModifyEndSHA512SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 6*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, }, // Modifying a byte in the middle verified block should cause // verify to fail. { - dataSize: 8 * usermem.PageSize, - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 5*usermem.PageSize + 123, - shouldSucceed: false, + name: "ModifyMiddleSHA256SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 5*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyMiddleSHA512SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 5*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyMiddleSHA256SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 5*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "ModifyMiddleSHA512SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 5*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, }, // Modifying a byte in the first block in the verified range // should cause verify to fail, even the modified bit itself is // out of verify range. { - dataSize: 8 * usermem.PageSize, - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 4*usermem.PageSize + 122, - shouldSucceed: false, + name: "ModifyFirstBlockSHA256SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 4*usermem.PageSize + 122, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyFirstBlockSHA512SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 4*usermem.PageSize + 122, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyFirstBlockSHA256SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 4*usermem.PageSize + 122, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "ModifyFirstBlockSHA512SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 4*usermem.PageSize + 122, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, }, // Modifying a byte in the last block in the verified range // should cause verify to fail, even the modified bit itself is // out of verify range. { - dataSize: 8 * usermem.PageSize, - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 6*usermem.PageSize + 124, - shouldSucceed: false, + name: "ModifyLastBlockSHA256SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 6*usermem.PageSize + 124, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyLastBlockSHA512SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 6*usermem.PageSize + 124, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyLastBlockSHA256SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 6*usermem.PageSize + 124, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "ModifyLastBlockSHA512SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 6*usermem.PageSize + 124, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, }, } - for _, tc := range testCases { - t.Run(fmt.Sprintf("%d", tc.modifyByte), func(t *testing.T) { - data := make([]byte, tc.dataSize) - // Generate random bytes in data. - rand.Read(data) - - for _, hashAlgorithms := range []int{linux.FS_VERITY_HASH_ALG_SHA256, linux.FS_VERITY_HASH_ALG_SHA512} { - for _, dataAndTreeInSameFile := range []bool{false, true} { - var tree bytesReadWriter - genParams := GenerateParams{ - Size: int64(len(data)), - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - HashAlgorithms: hashAlgorithms, - TreeReader: &tree, - TreeWriter: &tree, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } - if dataAndTreeInSameFile { - tree.Write(data) - genParams.File = &tree - } else { - genParams.File = &bytesReadWriter{ - bytes: data, - } - } - hash, err := Generate(&genParams) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - - // Flip a bit in data and checks Verify results. - var buf bytes.Buffer - data[tc.modifyByte] ^= 1 - verifyParams := VerifyParams{ - Out: &buf, - File: bytes.NewReader(data), - Tree: &tree, - Size: tc.dataSize, - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - HashAlgorithms: hashAlgorithms, - ReadOffset: tc.verifyStart, - ReadSize: tc.verifySize, - Expected: hash, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } - if tc.modifyName { - verifyParams.Name = defaultName + "abc" - } - if tc.modifySize { - verifyParams.Size-- - } - if tc.modifyMode { - verifyParams.Mode = defaultMode + 1 - } - if tc.modifyUID { - verifyParams.UID = defaultUID + 1 - } - if tc.modifyGID { - verifyParams.GID = defaultGID + 1 - } - if tc.shouldSucceed { - n, err := Verify(&verifyParams) - if err != nil && err != io.EOF { - t.Errorf("Verification failed when expected to succeed: %v", err) - } - if n != tc.verifySize { - t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize) - } - if int64(buf.Len()) != tc.verifySize { - t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize) - } - if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) { - t.Errorf("Incorrect output buf from Verify") - } - } else { - if _, err := Verify(&verifyParams); err == nil { - t.Errorf("Verification succeeded when expected to fail") - } - } - } + t.Run(tc.name, func(t *testing.T) { + dataSize := int64(8 * usermem.PageSize) + var buf bytes.Buffer + data, params := prepareVerify(t, dataSize, tc.hashAlgorithm, tc.dataAndTreeInSameFile, tc.verifyStart, tc.verifySize, &buf) + // Flip a bit in data and checks Verify results. + data[tc.modifyByte] ^= 1 + if _, err := Verify(¶ms); errors.Is(err, nil) { + t.Errorf("Verification succeeded when expected to fail") } }) } } func TestVerifyRandom(t *testing.T) { - rand.Seed(time.Now().UnixNano()) - // Use a random dataSize. Minimum size 2 so that we can pick a random - // portion from it. - dataSize := rand.Int63n(200*usermem.PageSize) + 2 - data := make([]byte, dataSize) - // Generate random bytes in data. - rand.Read(data) - - for _, hashAlgorithms := range []int{linux.FS_VERITY_HASH_ALG_SHA256, linux.FS_VERITY_HASH_ALG_SHA512} { - for _, dataAndTreeInSameFile := range []bool{false, true} { - var tree bytesReadWriter - genParams := GenerateParams{ - Size: int64(len(data)), - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - HashAlgorithms: hashAlgorithms, - TreeReader: &tree, - TreeWriter: &tree, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } - - if dataAndTreeInSameFile { - tree.Write(data) - genParams.File = &tree - } else { - genParams.File = &bytesReadWriter{ - bytes: data, - } - } - hash, err := Generate(&genParams) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } + testCases := []struct { + name string + hashAlgorithm int + dataAndTreeInSameFile bool + }{ + { + name: "SHA256SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "SHA512SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "SHA256SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "SHA512SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + // Use a random dataSize. Minimum size 2 so that we can pick a random + // portion from it. + dataSize := rand.Int63n(200*usermem.PageSize) + 2 // Pick a random portion of data. start := rand.Int63n(dataSize - 1) size := rand.Int63n(dataSize) + 1 var buf bytes.Buffer - verifyParams := VerifyParams{ - Out: &buf, - File: bytes.NewReader(data), - Tree: &tree, - Size: dataSize, - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - HashAlgorithms: hashAlgorithms, - ReadOffset: start, - ReadSize: size, - Expected: hash, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } + data, params := prepareVerify(t, dataSize, tc.hashAlgorithm, tc.dataAndTreeInSameFile, start, size, &buf) // Checks that the random portion of data from the original data is // verified successfully. - n, err := Verify(&verifyParams) + n, err := Verify(¶ms) if err != nil && err != io.EOF { t.Errorf("Verification failed for correct data: %v", err) } @@ -608,8 +1087,8 @@ func TestVerifyRandom(t *testing.T) { // Verify that modified metadata should fail verification. buf.Reset() - verifyParams.Name = defaultName + "abc" - if _, err := Verify(&verifyParams); err == nil { + params.Name = defaultName + "abc" + if _, err := Verify(¶ms); errors.Is(err, nil) { t.Error("Verify succeeded for modified metadata, expect failure") } @@ -617,12 +1096,12 @@ func TestVerifyRandom(t *testing.T) { buf.Reset() randBytePos := rand.Int63n(size) data[start+randBytePos] ^= 1 - verifyParams.File = bytes.NewReader(data) - verifyParams.Name = defaultName + params.File = bytes.NewReader(data) + params.Name = defaultName - if _, err := Verify(&verifyParams); err == nil { + if _, err := Verify(¶ms); errors.Is(err, nil) { t.Error("Verification succeeded for modified data, expect failure") } - } + }) } } diff --git a/pkg/p9/client.go b/pkg/p9/client.go index 71e944c30..eadea390a 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -570,6 +570,8 @@ func (c *Client) Version() uint32 { func (c *Client) Close() { // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already // been called (by c.watch()). - c.socket.Shutdown() + if err := c.socket.Shutdown(); err != nil { + log.Warningf("Socket.Shutdown() failed (FD: %d): %v", c.socket.FD(), err) + } c.closedWg.Wait() } diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go index 28fe081d6..8b46a2987 100644 --- a/pkg/p9/client_file.go +++ b/pkg/p9/client_file.go @@ -478,28 +478,23 @@ func (r *ReadWriterFile) ReadAt(p []byte, offset int64) (int, error) { } // Write implements part of the io.ReadWriter interface. +// +// Note that this may return a short write with a nil error. This violates the +// contract of io.Writer, but is more consistent with gVisor's pattern of +// returning errors that correspond to Linux errnos. Since short writes without +// error are common in Linux, returning a nil error is appropriate. func (r *ReadWriterFile) Write(p []byte) (int, error) { n, err := r.File.WriteAt(p, r.Offset) r.Offset += uint64(n) - if err != nil { - return n, err - } - if n < len(p) { - return n, io.ErrShortWrite - } - return n, nil + return n, err } // WriteAt implements the io.WriteAt interface. +// +// Note that this may return a short write with a nil error. This violates the +// contract of io.WriterAt. See comment on Write for justification. func (r *ReadWriterFile) WriteAt(p []byte, offset int64) (int, error) { - n, err := r.File.WriteAt(p, uint64(offset)) - if err != nil { - return n, err - } - if n < len(p) { - return n, io.ErrShortWrite - } - return n, nil + return r.File.WriteAt(p, uint64(offset)) } // Rename implements File.Rename. diff --git a/pkg/refsvfs2/README.md b/pkg/refsvfs2/README.md new file mode 100644 index 000000000..eca53c282 --- /dev/null +++ b/pkg/refsvfs2/README.md @@ -0,0 +1,66 @@ +# Reference Counting + +Go does not offer a reliable way to couple custom resource management with +object lifetime. As a result, we need to manually implement reference counting +for many objects in gVisor to make sure that resources are acquired and released +appropriately. For example, the filesystem has many reference-counted objects +(file descriptions, dentries, inodes, etc.), and it is important that each +object persists while anything holds a reference on it and is destroyed once all +references are dropped. + +We provide a template in `refs_template.go` that can be applied to most objects +in need of reference counting. It contains a simple `Refs` struct that can be +incremented and decremented, and once the reference count reaches zero, a +destructor can be called. Note that there are some objects (e.g. `gofer.dentry`, +`overlay.dentry`) that should not immediately be destroyed upon reaching zero +references; in these cases, this template cannot be applied. + +# Reference Checking + +Unfortunately, manually keeping track of reference counts is extremely error +prone, and improper accounting can lead to production bugs that are very +difficult to root cause. + +We have several ways of discovering reference count errors in gVisor. Any +attempt to increment/decrement a `Refs` struct with a count of zero will trigger +a sentry panic, since the object should have been destroyed and become +unreachable. This allows us to identify missing increments or extra decrements, +which cause the reference count to be lower than it should be: the count will +reach zero earlier than expected, and the next increment/decrement--which should +be valid--will result in a panic. + +It is trickier to identify extra increments and missing decrements, which cause +the reference count to be higher than expected (i.e. a “reference leak”). +Reference leaks prevent resources from being released properly and can translate +to various issues that are tricky to diagnose, such as memory leaks. The +following section discusses how we implement leak checking. + +## Leak Checking + +When leak checking is enabled, reference-counted objects are added to a global +map when constructed and removed when destroyed. Near the very end of sandbox +execution, once no reference-counted objects should still be reachable, we +report everything left in the map as having leaked. Leak-checking objects +implement the `CheckedObject` interface, which allows us to print informative +warnings for each of the leaked objects. + +Leak checking is provided by `refs_template`, but objects that do not use the +template will also need to implement `CheckedObject` and be manually +registered/unregistered from the map in order to be checked. + +Note that leak checking affects performance and memory usage, so it should only +be enabled in testing environments. + +## Debugging + +Even with the checks described above, it can be difficult to track down the +exact source of a reference counting error. The error may occur far before it is +discovered (for instance, a missing `IncRef` may not be discovered until a +future `DecRef` makes the count negative). To aid in debugging, `refs_template` +provides the `enableLogging` option to log every `IncRef`, `DecRef`, and leak +check registration/unregistration, along with the object address and a call +stack. This allows us to search a log for all of the changes to a particular +object's reference count, which makes it much easier to identify the absent or +extraneous operation(s). The reference-counted objects that do not use +`refs_template` also provide logging, and others defined in the future should do +so as well. diff --git a/pkg/sentry/devices/tundev/BUILD b/pkg/sentry/devices/tundev/BUILD index 14a8bf9cd..71c59287c 100644 --- a/pkg/sentry/devices/tundev/BUILD +++ b/pkg/sentry/devices/tundev/BUILD @@ -17,7 +17,6 @@ go_library( "//pkg/sentry/vfs", "//pkg/syserror", "//pkg/tcpip/link/tun", - "//pkg/tcpip/network/arp", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go index ff5d49fbd..d8f4e1d35 100644 --- a/pkg/sentry/devices/tundev/tundev.go +++ b/pkg/sentry/devices/tundev/tundev.go @@ -16,8 +16,6 @@ package tundev import ( - "fmt" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -28,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/link/tun" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -91,16 +88,7 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg return 0, err } flags := usermem.ByteOrder.Uint16(req.Data[:]) - created, err := fd.device.SetIff(stack.Stack, req.Name(), flags) - if err == nil && created { - // Always start with an ARP address for interfaces so they can handle ARP - // packets. - nicID := fd.device.NICID() - if err := stack.Stack.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - panic(fmt.Sprintf("failed to add ARP address after creating new TUN/TAP interface with ID = %d", nicID)) - } - } - return 0, err + return 0, fd.device.SetIff(stack.Stack, req.Name(), flags) case linux.TUNGETIFF: var req linux.IFReq diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index 6b7b451b8..9379a4d7b 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -34,7 +34,6 @@ go_library( "//pkg/sentry/socket/netstack", "//pkg/syserror", "//pkg/tcpip/link/tun", - "//pkg/tcpip/network/arp", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go index 19ffdec47..5227ef652 100644 --- a/pkg/sentry/fs/dev/net_tun.go +++ b/pkg/sentry/fs/dev/net_tun.go @@ -15,8 +15,6 @@ package dev import ( - "fmt" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -27,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/link/tun" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -112,16 +109,7 @@ func (n *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io user return 0, err } flags := usermem.ByteOrder.Uint16(req.Data[:]) - created, err := n.device.SetIff(stack.Stack, req.Name(), flags) - if err == nil && created { - // Always start with an ARP address for interfaces so they can handle ARP - // packets. - nicID := n.device.NICID() - if err := stack.Stack.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - panic(fmt.Sprintf("failed to add ARP address after creating new TUN/TAP interface with ID = %d", nicID)) - } - } - return 0, err + return 0, n.device.SetIff(stack.Stack, req.Name(), flags) case linux.TUNGETIFF: var req linux.IFReq diff --git a/pkg/sentry/fs/proc/exec_args.go b/pkg/sentry/fs/proc/exec_args.go index 8fe626e1c..e6171dd1d 100644 --- a/pkg/sentry/fs/proc/exec_args.go +++ b/pkg/sentry/fs/proc/exec_args.go @@ -57,16 +57,16 @@ type execArgInode struct { var _ fs.InodeOperations = (*execArgInode)(nil) // newExecArgFile creates a file containing the exec args of the given type. -func newExecArgInode(t *kernel.Task, msrc *fs.MountSource, arg execArgType) *fs.Inode { +func newExecArgInode(ctx context.Context, t *kernel.Task, msrc *fs.MountSource, arg execArgType) *fs.Inode { if arg != cmdlineExecArg && arg != environExecArg { panic(fmt.Sprintf("unknown exec arg type %v", arg)) } f := &execArgInode{ - SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), arg: arg, t: t, } - return newProcInode(t, f, msrc, fs.SpecialFile, t) + return newProcInode(ctx, f, msrc, fs.SpecialFile, t) } // GetFile implements fs.InodeOperations.GetFile. diff --git a/pkg/sentry/fs/proc/fds.go b/pkg/sentry/fs/proc/fds.go index 45523adf8..e90da225a 100644 --- a/pkg/sentry/fs/proc/fds.go +++ b/pkg/sentry/fs/proc/fds.go @@ -95,13 +95,13 @@ var _ fs.InodeOperations = (*fd)(nil) // newFd returns a new fd based on an existing file. // // This inherits one reference to the file. -func newFd(t *kernel.Task, f *fs.File, msrc *fs.MountSource) *fs.Inode { +func newFd(ctx context.Context, t *kernel.Task, f *fs.File, msrc *fs.MountSource) *fs.Inode { fd := &fd{ // RootOwner overridden by taskOwnedInodeOps.UnstableAttrs(). - Symlink: *ramfs.NewSymlink(t, fs.RootOwner, ""), + Symlink: *ramfs.NewSymlink(ctx, fs.RootOwner, ""), file: f, } - return newProcInode(t, fd, msrc, fs.Symlink, t) + return newProcInode(ctx, fd, msrc, fs.Symlink, t) } // GetFile returns the fs.File backing this fd. The dirent and flags @@ -153,12 +153,12 @@ type fdDir struct { var _ fs.InodeOperations = (*fdDir)(nil) // newFdDir creates a new fdDir. -func newFdDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newFdDir(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { f := &fdDir{ - Dir: *ramfs.NewDir(t, nil, fs.RootOwner, fs.FilePermissions{User: fs.PermMask{Read: true, Execute: true}}), + Dir: *ramfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermissions{User: fs.PermMask{Read: true, Execute: true}}), t: t, } - return newProcInode(t, f, msrc, fs.SpecialDirectory, t) + return newProcInode(ctx, f, msrc, fs.SpecialDirectory, t) } // Check implements InodeOperations.Check. @@ -183,7 +183,7 @@ func (f *fdDir) Check(ctx context.Context, inode *fs.Inode, req fs.PermMask) boo // Lookup loads an Inode in /proc/TID/fd into a Dirent. func (f *fdDir) Lookup(ctx context.Context, dir *fs.Inode, p string) (*fs.Dirent, error) { n, err := walkDescriptors(f.t, p, func(file *fs.File, _ kernel.FDFlags) *fs.Inode { - return newFd(f.t, file, dir.MountSource) + return newFd(ctx, f.t, file, dir.MountSource) }) if err != nil { return nil, err @@ -237,12 +237,12 @@ type fdInfoDir struct { } // newFdInfoDir creates a new fdInfoDir. -func newFdInfoDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newFdInfoDir(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { fdid := &fdInfoDir{ - Dir: *ramfs.NewDir(t, nil, fs.RootOwner, fs.FilePermsFromMode(0500)), + Dir: *ramfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0500)), t: t, } - return newProcInode(t, fdid, msrc, fs.SpecialDirectory, t) + return newProcInode(ctx, fdid, msrc, fs.SpecialDirectory, t) } // Lookup loads an fd in /proc/TID/fdinfo into a Dirent. diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 83a43aa26..03127f816 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -41,7 +41,7 @@ import ( // LINT.IfChange // newNetDir creates a new proc net entry. -func newNetDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newNetDir(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { k := t.Kernel() var contents map[string]*fs.Inode @@ -49,39 +49,39 @@ func newNetDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task // network namespace. contents = map[string]*fs.Inode{ - "dev": seqfile.NewSeqFileInode(t, &netDev{s: s}, msrc), - "snmp": seqfile.NewSeqFileInode(t, &netSnmp{s: s}, msrc), + "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc), + "snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc), // The following files are simple stubs until they are // implemented in netstack, if the file contains a // header the stub is just the header otherwise it is // an empty file. - "arp": newStaticProcInode(t, msrc, []byte("IP address HW type Flags HW address Mask Device\n")), + "arp": newStaticProcInode(ctx, msrc, []byte("IP address HW type Flags HW address Mask Device\n")), - "netlink": newStaticProcInode(t, msrc, []byte("sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n")), - "netstat": newStaticProcInode(t, msrc, []byte("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed EmbryonicRsts PruneCalled RcvPruned OfoPruned OutOfWindowIcmps LockDroppedIcmps ArpFilter TW TWRecycled TWKilled PAWSPassive PAWSActive PAWSEstab DelayedACKs DelayedACKLocked DelayedACKLost ListenOverflows ListenDrops TCPPrequeued TCPDirectCopyFromBacklog TCPDirectCopyFromPrequeue TCPPrequeueDropped TCPHPHits TCPHPHitsToUser TCPPureAcks TCPHPAcks TCPRenoRecovery TCPSackRecovery TCPSACKReneging TCPFACKReorder TCPSACKReorder TCPRenoReorder TCPTSReorder TCPFullUndo TCPPartialUndo TCPDSACKUndo TCPLossUndo TCPLostRetransmit TCPRenoFailures TCPSackFailures TCPLossFailures TCPFastRetrans TCPForwardRetrans TCPSlowStartRetrans TCPTimeouts TCPLossProbes TCPLossProbeRecovery TCPRenoRecoveryFail TCPSackRecoveryFail TCPSchedulerFailed TCPRcvCollapsed TCPDSACKOldSent TCPDSACKOfoSent TCPDSACKRecv TCPDSACKOfoRecv TCPAbortOnData TCPAbortOnClose TCPAbortOnMemory TCPAbortOnTimeout TCPAbortOnLinger TCPAbortFailed TCPMemoryPressures TCPSACKDiscard TCPDSACKIgnoredOld TCPDSACKIgnoredNoUndo TCPSpuriousRTOs TCPMD5NotFound TCPMD5Unexpected TCPMD5Failure TCPSackShifted TCPSackMerged TCPSackShiftFallback TCPBacklogDrop TCPMinTTLDrop TCPDeferAcceptDrop IPReversePathFilter TCPTimeWaitOverflow TCPReqQFullDoCookies TCPReqQFullDrop TCPRetransFail TCPRcvCoalesce TCPOFOQueue TCPOFODrop TCPOFOMerge TCPChallengeACK TCPSYNChallenge TCPFastOpenActive TCPFastOpenActiveFail TCPFastOpenPassive TCPFastOpenPassiveFail TCPFastOpenListenOverflow TCPFastOpenCookieReqd TCPSpuriousRtxHostQueues BusyPollRxPackets TCPAutoCorking TCPFromZeroWindowAdv TCPToZeroWindowAdv TCPWantZeroWindowAdv TCPSynRetrans TCPOrigDataSent TCPHystartTrainDetect TCPHystartTrainCwnd TCPHystartDelayDetect TCPHystartDelayCwnd TCPACKSkippedSynRecv TCPACKSkippedPAWS TCPACKSkippedSeq TCPACKSkippedFinWait2 TCPACKSkippedTimeWait TCPACKSkippedChallenge TCPWinProbe TCPKeepAlive TCPMTUPFail TCPMTUPSuccess\n")), - "packet": newStaticProcInode(t, msrc, []byte("sk RefCnt Type Proto Iface R Rmem User Inode\n")), - "protocols": newStaticProcInode(t, msrc, []byte("protocol size sockets memory press maxhdr slab module cl co di ac io in de sh ss gs se re sp bi br ha uh gp em\n")), + "netlink": newStaticProcInode(ctx, msrc, []byte("sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n")), + "netstat": newStaticProcInode(ctx, msrc, []byte("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed EmbryonicRsts PruneCalled RcvPruned OfoPruned OutOfWindowIcmps LockDroppedIcmps ArpFilter TW TWRecycled TWKilled PAWSPassive PAWSActive PAWSEstab DelayedACKs DelayedACKLocked DelayedACKLost ListenOverflows ListenDrops TCPPrequeued TCPDirectCopyFromBacklog TCPDirectCopyFromPrequeue TCPPrequeueDropped TCPHPHits TCPHPHitsToUser TCPPureAcks TCPHPAcks TCPRenoRecovery TCPSackRecovery TCPSACKReneging TCPFACKReorder TCPSACKReorder TCPRenoReorder TCPTSReorder TCPFullUndo TCPPartialUndo TCPDSACKUndo TCPLossUndo TCPLostRetransmit TCPRenoFailures TCPSackFailures TCPLossFailures TCPFastRetrans TCPForwardRetrans TCPSlowStartRetrans TCPTimeouts TCPLossProbes TCPLossProbeRecovery TCPRenoRecoveryFail TCPSackRecoveryFail TCPSchedulerFailed TCPRcvCollapsed TCPDSACKOldSent TCPDSACKOfoSent TCPDSACKRecv TCPDSACKOfoRecv TCPAbortOnData TCPAbortOnClose TCPAbortOnMemory TCPAbortOnTimeout TCPAbortOnLinger TCPAbortFailed TCPMemoryPressures TCPSACKDiscard TCPDSACKIgnoredOld TCPDSACKIgnoredNoUndo TCPSpuriousRTOs TCPMD5NotFound TCPMD5Unexpected TCPMD5Failure TCPSackShifted TCPSackMerged TCPSackShiftFallback TCPBacklogDrop TCPMinTTLDrop TCPDeferAcceptDrop IPReversePathFilter TCPTimeWaitOverflow TCPReqQFullDoCookies TCPReqQFullDrop TCPRetransFail TCPRcvCoalesce TCPOFOQueue TCPOFODrop TCPOFOMerge TCPChallengeACK TCPSYNChallenge TCPFastOpenActive TCPFastOpenActiveFail TCPFastOpenPassive TCPFastOpenPassiveFail TCPFastOpenListenOverflow TCPFastOpenCookieReqd TCPSpuriousRtxHostQueues BusyPollRxPackets TCPAutoCorking TCPFromZeroWindowAdv TCPToZeroWindowAdv TCPWantZeroWindowAdv TCPSynRetrans TCPOrigDataSent TCPHystartTrainDetect TCPHystartTrainCwnd TCPHystartDelayDetect TCPHystartDelayCwnd TCPACKSkippedSynRecv TCPACKSkippedPAWS TCPACKSkippedSeq TCPACKSkippedFinWait2 TCPACKSkippedTimeWait TCPACKSkippedChallenge TCPWinProbe TCPKeepAlive TCPMTUPFail TCPMTUPSuccess\n")), + "packet": newStaticProcInode(ctx, msrc, []byte("sk RefCnt Type Proto Iface R Rmem User Inode\n")), + "protocols": newStaticProcInode(ctx, msrc, []byte("protocol size sockets memory press maxhdr slab module cl co di ac io in de sh ss gs se re sp bi br ha uh gp em\n")), // Linux sets psched values to: nsec per usec, psched // tick in ns, 1000000, high res timer ticks per sec // (ClockGetres returns 1ns resolution). - "psched": newStaticProcInode(t, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))), - "ptype": newStaticProcInode(t, msrc, []byte("Type Device Function\n")), - "route": seqfile.NewSeqFileInode(t, &netRoute{s: s}, msrc), - "tcp": seqfile.NewSeqFileInode(t, &netTCP{k: k}, msrc), - "udp": seqfile.NewSeqFileInode(t, &netUDP{k: k}, msrc), - "unix": seqfile.NewSeqFileInode(t, &netUnix{k: k}, msrc), + "psched": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))), + "ptype": newStaticProcInode(ctx, msrc, []byte("Type Device Function\n")), + "route": seqfile.NewSeqFileInode(ctx, &netRoute{s: s}, msrc), + "tcp": seqfile.NewSeqFileInode(ctx, &netTCP{k: k}, msrc), + "udp": seqfile.NewSeqFileInode(ctx, &netUDP{k: k}, msrc), + "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc), } if s.SupportsIPv6() { - contents["if_inet6"] = seqfile.NewSeqFileInode(t, &ifinet6{s: s}, msrc) - contents["ipv6_route"] = newStaticProcInode(t, msrc, []byte("")) - contents["tcp6"] = seqfile.NewSeqFileInode(t, &netTCP6{k: k}, msrc) - contents["udp6"] = newStaticProcInode(t, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n")) + contents["if_inet6"] = seqfile.NewSeqFileInode(ctx, &ifinet6{s: s}, msrc) + contents["ipv6_route"] = newStaticProcInode(ctx, msrc, []byte("")) + contents["tcp6"] = seqfile.NewSeqFileInode(ctx, &netTCP6{k: k}, msrc) + contents["udp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n")) } } - d := ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) - return newProcInode(t, d, msrc, fs.SpecialDirectory, t) + d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) + return newProcInode(ctx, d, msrc, fs.SpecialDirectory, t) } // ifinet6 implements seqfile.SeqSource for /proc/net/if_inet6. diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go index 77e0e1d26..2f2a9f920 100644 --- a/pkg/sentry/fs/proc/proc.go +++ b/pkg/sentry/fs/proc/proc.go @@ -179,7 +179,7 @@ func (p *proc) Lookup(ctx context.Context, dir *fs.Inode, name string) (*fs.Dire } // Wrap it in a taskDir. - td := p.newTaskDir(otherTask, dir.MountSource, true) + td := p.newTaskDir(ctx, otherTask, dir.MountSource, true) return fs.NewDirent(ctx, td, name), nil } diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 450044c9c..f43d6c221 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -79,45 +79,45 @@ type taskDir struct { var _ fs.InodeOperations = (*taskDir)(nil) // newTaskDir creates a new proc task entry. -func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode { +func (p *proc) newTaskDir(ctx context.Context, t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode { contents := map[string]*fs.Inode{ - "auxv": newAuxvec(t, msrc), - "cmdline": newExecArgInode(t, msrc, cmdlineExecArg), - "comm": newComm(t, msrc), - "cwd": newCwd(t, msrc), - "environ": newExecArgInode(t, msrc, environExecArg), - "exe": newExe(t, msrc), - "fd": newFdDir(t, msrc), - "fdinfo": newFdInfoDir(t, msrc), - "gid_map": newGIDMap(t, msrc), - "io": newIO(t, msrc, isThreadGroup), - "maps": newMaps(t, msrc), - "mem": newMem(t, msrc), - "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc), - "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc), - "net": newNetDir(t, msrc), - "ns": newNamespaceDir(t, msrc), - "oom_score": newOOMScore(t, msrc), - "oom_score_adj": newOOMScoreAdj(t, msrc), - "smaps": newSmaps(t, msrc), - "stat": newTaskStat(t, msrc, isThreadGroup, p.pidns), - "statm": newStatm(t, msrc), - "status": newStatus(t, msrc, p.pidns), - "uid_map": newUIDMap(t, msrc), + "auxv": newAuxvec(ctx, t, msrc), + "cmdline": newExecArgInode(ctx, t, msrc, cmdlineExecArg), + "comm": newComm(ctx, t, msrc), + "cwd": newCwd(ctx, t, msrc), + "environ": newExecArgInode(ctx, t, msrc, environExecArg), + "exe": newExe(ctx, t, msrc), + "fd": newFdDir(ctx, t, msrc), + "fdinfo": newFdInfoDir(ctx, t, msrc), + "gid_map": newGIDMap(ctx, t, msrc), + "io": newIO(ctx, t, msrc, isThreadGroup), + "maps": newMaps(ctx, t, msrc), + "mem": newMem(ctx, t, msrc), + "mountinfo": seqfile.NewSeqFileInode(ctx, &mountInfoFile{t: t}, msrc), + "mounts": seqfile.NewSeqFileInode(ctx, &mountsFile{t: t}, msrc), + "net": newNetDir(ctx, t, msrc), + "ns": newNamespaceDir(ctx, t, msrc), + "oom_score": newOOMScore(ctx, msrc), + "oom_score_adj": newOOMScoreAdj(ctx, t, msrc), + "smaps": newSmaps(ctx, t, msrc), + "stat": newTaskStat(ctx, t, msrc, isThreadGroup, p.pidns), + "statm": newStatm(ctx, t, msrc), + "status": newStatus(ctx, t, msrc, p.pidns), + "uid_map": newUIDMap(ctx, t, msrc), } if isThreadGroup { - contents["task"] = p.newSubtasks(t, msrc) + contents["task"] = p.newSubtasks(ctx, t, msrc) } if len(p.cgroupControllers) > 0 { - contents["cgroup"] = newCGroupInode(t, msrc, p.cgroupControllers) + contents["cgroup"] = newCGroupInode(ctx, msrc, p.cgroupControllers) } // N.B. taskOwnedInodeOps enforces dumpability-based ownership. d := &taskDir{ - Dir: *ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0555)), + Dir: *ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)), t: t, } - return newProcInode(t, d, msrc, fs.SpecialDirectory, t) + return newProcInode(ctx, d, msrc, fs.SpecialDirectory, t) } // subtasks represents a /proc/TID/task directory. @@ -132,13 +132,13 @@ type subtasks struct { var _ fs.InodeOperations = (*subtasks)(nil) -func (p *proc) newSubtasks(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func (p *proc) newSubtasks(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { s := &subtasks{ - Dir: *ramfs.NewDir(t, nil, fs.RootOwner, fs.FilePermsFromMode(0555)), + Dir: *ramfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0555)), t: t, p: p, } - return newProcInode(t, s, msrc, fs.SpecialDirectory, t) + return newProcInode(ctx, s, msrc, fs.SpecialDirectory, t) } // UnstableAttr returns unstable attributes of the subtasks. @@ -243,7 +243,7 @@ func (s *subtasks) Lookup(ctx context.Context, dir *fs.Inode, p string) (*fs.Dir return nil, syserror.ENOENT } - td := s.p.newTaskDir(task, dir.MountSource, false) + td := s.p.newTaskDir(ctx, task, dir.MountSource, false) return fs.NewDirent(ctx, td, p), nil } @@ -256,12 +256,12 @@ type exe struct { t *kernel.Task } -func newExe(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newExe(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { exeSymlink := &exe{ - Symlink: *ramfs.NewSymlink(t, fs.RootOwner, ""), + Symlink: *ramfs.NewSymlink(ctx, fs.RootOwner, ""), t: t, } - return newProcInode(t, exeSymlink, msrc, fs.Symlink, t) + return newProcInode(ctx, exeSymlink, msrc, fs.Symlink, t) } func (e *exe) executable() (file fsbridge.File, err error) { @@ -311,12 +311,12 @@ type cwd struct { t *kernel.Task } -func newCwd(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newCwd(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { cwdSymlink := &cwd{ - Symlink: *ramfs.NewSymlink(t, fs.RootOwner, ""), + Symlink: *ramfs.NewSymlink(ctx, fs.RootOwner, ""), t: t, } - return newProcInode(t, cwdSymlink, msrc, fs.Symlink, t) + return newProcInode(ctx, cwdSymlink, msrc, fs.Symlink, t) } // Readlink implements fs.InodeOperations. @@ -355,17 +355,17 @@ type namespaceSymlink struct { t *kernel.Task } -func newNamespaceSymlink(t *kernel.Task, msrc *fs.MountSource, name string) *fs.Inode { +func newNamespaceSymlink(ctx context.Context, t *kernel.Task, msrc *fs.MountSource, name string) *fs.Inode { // TODO(rahat): Namespace symlinks should contain the namespace name and the // inode number for the namespace instance, so for example user:[123456]. We // currently fake the inode number by sticking the symlink inode in its // place. target := fmt.Sprintf("%s:[%d]", name, device.ProcDevice.NextIno()) n := &namespaceSymlink{ - Symlink: *ramfs.NewSymlink(t, fs.RootOwner, target), + Symlink: *ramfs.NewSymlink(ctx, fs.RootOwner, target), t: t, } - return newProcInode(t, n, msrc, fs.Symlink, t) + return newProcInode(ctx, n, msrc, fs.Symlink, t) } // Readlink reads the symlink value. @@ -390,14 +390,14 @@ func (n *namespaceSymlink) Getlink(ctx context.Context, inode *fs.Inode) (*fs.Di return fs.NewDirent(ctx, newProcInode(ctx, iops, inode.MountSource, fs.RegularFile, nil), n.Symlink.Target), nil } -func newNamespaceDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newNamespaceDir(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { contents := map[string]*fs.Inode{ - "net": newNamespaceSymlink(t, msrc, "net"), - "pid": newNamespaceSymlink(t, msrc, "pid"), - "user": newNamespaceSymlink(t, msrc, "user"), + "net": newNamespaceSymlink(ctx, t, msrc, "net"), + "pid": newNamespaceSymlink(ctx, t, msrc, "pid"), + "user": newNamespaceSymlink(ctx, t, msrc, "user"), } - d := ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0511)) - return newProcInode(t, d, msrc, fs.SpecialDirectory, t) + d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0511)) + return newProcInode(ctx, d, msrc, fs.SpecialDirectory, t) } // memData implements fs.Inode for /proc/[pid]/mem. @@ -428,12 +428,12 @@ type memDataFile struct { t *kernel.Task } -func newMem(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newMem(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { inode := &memData{ - SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0400), linux.PROC_SUPER_MAGIC), + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0400), linux.PROC_SUPER_MAGIC), t: t, } - return newProcInode(t, inode, msrc, fs.SpecialFile, t) + return newProcInode(ctx, inode, msrc, fs.SpecialFile, t) } // Truncate implements fs.InodeOperations.Truncate. @@ -489,8 +489,8 @@ type mapsData struct { t *kernel.Task } -func newMaps(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { - return newProcInode(t, seqfile.NewSeqFile(t, &mapsData{t}), msrc, fs.SpecialFile, t) +func newMaps(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + return newProcInode(ctx, seqfile.NewSeqFile(ctx, &mapsData{t}), msrc, fs.SpecialFile, t) } func (md *mapsData) mm() *mm.MemoryManager { @@ -529,8 +529,8 @@ type smapsData struct { t *kernel.Task } -func newSmaps(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { - return newProcInode(t, seqfile.NewSeqFile(t, &smapsData{t}), msrc, fs.SpecialFile, t) +func newSmaps(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + return newProcInode(ctx, seqfile.NewSeqFile(ctx, &smapsData{t}), msrc, fs.SpecialFile, t) } func (sd *smapsData) mm() *mm.MemoryManager { @@ -575,8 +575,8 @@ type taskStatData struct { pidns *kernel.PIDNamespace } -func newTaskStat(t *kernel.Task, msrc *fs.MountSource, showSubtasks bool, pidns *kernel.PIDNamespace) *fs.Inode { - return newProcInode(t, seqfile.NewSeqFile(t, &taskStatData{t, showSubtasks /* tgstats */, pidns}), msrc, fs.SpecialFile, t) +func newTaskStat(ctx context.Context, t *kernel.Task, msrc *fs.MountSource, showSubtasks bool, pidns *kernel.PIDNamespace) *fs.Inode { + return newProcInode(ctx, seqfile.NewSeqFile(ctx, &taskStatData{t, showSubtasks /* tgstats */, pidns}), msrc, fs.SpecialFile, t) } // NeedsUpdate returns whether the generation is old or not. @@ -660,8 +660,8 @@ type statmData struct { t *kernel.Task } -func newStatm(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { - return newProcInode(t, seqfile.NewSeqFile(t, &statmData{t}), msrc, fs.SpecialFile, t) +func newStatm(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + return newProcInode(ctx, seqfile.NewSeqFile(ctx, &statmData{t}), msrc, fs.SpecialFile, t) } // NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. @@ -697,8 +697,8 @@ type statusData struct { pidns *kernel.PIDNamespace } -func newStatus(t *kernel.Task, msrc *fs.MountSource, pidns *kernel.PIDNamespace) *fs.Inode { - return newProcInode(t, seqfile.NewSeqFile(t, &statusData{t, pidns}), msrc, fs.SpecialFile, t) +func newStatus(ctx context.Context, t *kernel.Task, msrc *fs.MountSource, pidns *kernel.PIDNamespace) *fs.Inode { + return newProcInode(ctx, seqfile.NewSeqFile(ctx, &statusData{t, pidns}), msrc, fs.SpecialFile, t) } // NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. @@ -768,11 +768,11 @@ type ioData struct { ioUsage } -func newIO(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode { +func newIO(ctx context.Context, t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode { if isThreadGroup { - return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t) + return newProcInode(ctx, seqfile.NewSeqFile(ctx, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t) } - return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t}), msrc, fs.SpecialFile, t) + return newProcInode(ctx, seqfile.NewSeqFile(ctx, &ioData{t}), msrc, fs.SpecialFile, t) } // NeedsUpdate returns whether the generation is old or not. @@ -816,12 +816,12 @@ type comm struct { } // newComm returns a new comm file. -func newComm(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newComm(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { c := &comm{ - SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), t: t, } - return newProcInode(t, c, msrc, fs.SpecialFile, t) + return newProcInode(ctx, c, msrc, fs.SpecialFile, t) } // Check implements fs.InodeOperations.Check. @@ -888,12 +888,12 @@ type auxvec struct { } // newAuxvec returns a new auxvec file. -func newAuxvec(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newAuxvec(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { a := &auxvec{ - SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), t: t, } - return newProcInode(t, a, msrc, fs.SpecialFile, t) + return newProcInode(ctx, a, msrc, fs.SpecialFile, t) } // GetFile implements fs.InodeOperations.GetFile. @@ -949,8 +949,8 @@ func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequenc // newOOMScore returns a oom_score file. It is a stub that always returns 0. // TODO(gvisor.dev/issue/1967) -func newOOMScore(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { - return newStaticProcInode(t, msrc, []byte("0\n")) +func newOOMScore(ctx context.Context, msrc *fs.MountSource) *fs.Inode { + return newStaticProcInode(ctx, msrc, []byte("0\n")) } // oomScoreAdj is a file containing the oom_score adjustment for a task. @@ -979,12 +979,12 @@ type oomScoreAdjFile struct { } // newOOMScoreAdj returns a oom_score_adj file. -func newOOMScoreAdj(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newOOMScoreAdj(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { i := &oomScoreAdj{ - SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), t: t, } - return newProcInode(t, i, msrc, fs.SpecialFile, t) + return newProcInode(ctx, i, msrc, fs.SpecialFile, t) } // Truncate implements fs.InodeOperations.Truncate. Truncate is called when diff --git a/pkg/sentry/fs/proc/uid_gid_map.go b/pkg/sentry/fs/proc/uid_gid_map.go index 8d9517b95..2bc9485d8 100644 --- a/pkg/sentry/fs/proc/uid_gid_map.go +++ b/pkg/sentry/fs/proc/uid_gid_map.go @@ -58,18 +58,18 @@ type idMapInodeOperations struct { var _ fs.InodeOperations = (*idMapInodeOperations)(nil) // newUIDMap returns a new uid_map file. -func newUIDMap(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { - return newIDMap(t, msrc, false /* gids */) +func newUIDMap(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + return newIDMap(ctx, t, msrc, false /* gids */) } // newGIDMap returns a new gid_map file. -func newGIDMap(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { - return newIDMap(t, msrc, true /* gids */) +func newGIDMap(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + return newIDMap(ctx, t, msrc, true /* gids */) } -func newIDMap(t *kernel.Task, msrc *fs.MountSource, gids bool) *fs.Inode { - return newProcInode(t, &idMapInodeOperations{ - InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(t, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), +func newIDMap(ctx context.Context, t *kernel.Task, msrc *fs.MountSource, gids bool) *fs.Inode { + return newProcInode(ctx, &idMapInodeOperations{ + InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), t: t, gids: gids, }, msrc, fs.SpecialFile, t) diff --git a/pkg/sentry/fsimpl/fuse/connection.go b/pkg/sentry/fsimpl/fuse/connection.go index 8ccda1264..34d25a61e 100644 --- a/pkg/sentry/fsimpl/fuse/connection.go +++ b/pkg/sentry/fsimpl/fuse/connection.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -193,11 +192,12 @@ func (conn *connection) loadInitializedChan(closed bool) { } } -// newFUSEConnection creates a FUSE connection to fd. -func newFUSEConnection(_ context.Context, fd *vfs.FileDescription, opts *filesystemOptions) (*connection, error) { - // Mark the device as ready so it can be used. /dev/fuse can only be used if the FD was used to - // mount a FUSE filesystem. - fuseFD := fd.Impl().(*DeviceFD) +// newFUSEConnection creates a FUSE connection to fuseFD. +func newFUSEConnection(_ context.Context, fuseFD *DeviceFD, opts *filesystemOptions) (*connection, error) { + // Mark the device as ready so it can be used. + // FIXME(gvisor.dev/issue/4813): fuseFD's fields are accessed without + // synchronization and without checking if fuseFD has already been used to + // mount another filesystem. // Create the writeBuf for the header to be stored in. hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) diff --git a/pkg/sentry/fsimpl/fuse/connection_control.go b/pkg/sentry/fsimpl/fuse/connection_control.go index bfde78559..1b3459c1d 100644 --- a/pkg/sentry/fsimpl/fuse/connection_control.go +++ b/pkg/sentry/fsimpl/fuse/connection_control.go @@ -198,7 +198,6 @@ func (conn *connection) Abort(ctx context.Context) { if !conn.connected { conn.asyncMu.Unlock() conn.mu.Unlock() - conn.fd.mu.Unlock() return } diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go index 1b86a4b4c..1bbe6fdb7 100644 --- a/pkg/sentry/fsimpl/fuse/dev.go +++ b/pkg/sentry/fsimpl/fuse/dev.go @@ -94,7 +94,8 @@ type DeviceFD struct { // unprocessed in-flight requests. fullQueueCh chan struct{} `state:".(int)"` - // fs is the FUSE filesystem that this FD is being used for. + // fs is the FUSE filesystem that this FD is being used for. A reference is + // held on fs. fs *filesystem } @@ -135,12 +136,6 @@ func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.R return 0, syserror.EPERM } - // Return ENODEV if the filesystem is umounted. - if fd.fs.umounted { - // TODO(gvisor.dev/issue/3525): return ECONNABORTED if aborted via fuse control fs. - return 0, syserror.ENODEV - } - // We require that any Read done on this filesystem have a sane minimum // read buffer. It must have the capacity for the fixed parts of any request // header (Linux uses the request header and the FUSEWriteIn header for this @@ -368,7 +363,7 @@ func (fd *DeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask { func (fd *DeviceFD) readinessLocked(mask waiter.EventMask) waiter.EventMask { var ready waiter.EventMask - if fd.fs.umounted { + if fd.fs == nil || fd.fs.umounted { ready |= waiter.EventErr return ready & mask } diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index cd0eb56e5..23e827f90 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -127,7 +127,13 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt log.Warningf("%s.GetFilesystem: couldn't get kernel task from context", fsType.Name()) return nil, nil, syserror.EINVAL } - fuseFd := kernelTask.GetFileVFS2(int32(deviceDescriptor)) + fuseFDGeneric := kernelTask.GetFileVFS2(int32(deviceDescriptor)) + defer fuseFDGeneric.DecRef(ctx) + fuseFD, ok := fuseFDGeneric.Impl().(*DeviceFD) + if !ok { + log.Warningf("%s.GetFilesystem: device FD is %T, not a FUSE device", fsType.Name, fuseFDGeneric) + return nil, nil, syserror.EINVAL + } // Parse and set all the other supported FUSE mount options. // TODO(gVisor.dev/issue/3229): Expand the supported mount options. @@ -189,18 +195,17 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } // Create a new FUSE filesystem. - fs, err := newFUSEFilesystem(ctx, devMinor, &fsopts, fuseFd) + fs, err := newFUSEFilesystem(ctx, vfsObj, &fsType, fuseFD, devMinor, &fsopts) if err != nil { log.Warningf("%s.NewFUSEFilesystem: failed with error: %v", fsType.Name(), err) return nil, nil, err } - fs.VFSFilesystem().Init(vfsObj, &fsType, fs) - // Send a FUSE_INIT request to the FUSE daemon server before returning. // This call is not blocking. if err := fs.conn.InitSend(creds, uint32(kernelTask.ThreadID())); err != nil { log.Warningf("%s.InitSend: failed with error: %v", fsType.Name(), err) + fs.VFSFilesystem().DecRef(ctx) // returned by newFUSEFilesystem return nil, nil, err } @@ -211,20 +216,28 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } // newFUSEFilesystem creates a new FUSE filesystem. -func newFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOptions, device *vfs.FileDescription) (*filesystem, error) { - conn, err := newFUSEConnection(ctx, device, opts) +func newFUSEFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, fsType *FilesystemType, fuseFD *DeviceFD, devMinor uint32, opts *filesystemOptions) (*filesystem, error) { + conn, err := newFUSEConnection(ctx, fuseFD, opts) if err != nil { log.Warningf("fuse.NewFUSEFilesystem: NewFUSEConnection failed with error: %v", err) return nil, syserror.EINVAL } - fuseFD := device.Impl().(*DeviceFD) fs := &filesystem{ devMinor: devMinor, opts: opts, conn: conn, } + fs.VFSFilesystem().Init(vfsObj, fsType, fs) + + // FIXME(gvisor.dev/issue/4813): Doesn't conn or fs need to hold a + // reference on fuseFD, since conn uses fuseFD for communication with the + // server? Wouldn't doing so create a circular reference? + fs.VFSFilesystem().IncRef() // for fuseFD.fs + // FIXME(gvisor.dev/issue/4813): fuseFD.fs is accessed without + // synchronization. fuseFD.fs = fs + return fs, nil } diff --git a/pkg/sentry/fsimpl/fuse/utils_test.go b/pkg/sentry/fsimpl/fuse/utils_test.go index b2f4276b8..2c0cc0f4e 100644 --- a/pkg/sentry/fsimpl/fuse/utils_test.go +++ b/pkg/sentry/fsimpl/fuse/utils_test.go @@ -52,28 +52,21 @@ func setup(t *testing.T) *testutil.System { // newTestConnection creates a fuse connection that the sentry can communicate with // and the FD for the server to communicate with. func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveRequests uint64) (*connection, *vfs.FileDescription, error) { - vfsObj := &vfs.VirtualFilesystem{} fuseDev := &DeviceFD{} - if err := vfsObj.Init(system.Ctx); err != nil { - return nil, nil, err - } - - vd := vfsObj.NewAnonVirtualDentry("genCountFD") + vd := system.VFS.NewAnonVirtualDentry("fuse") defer vd.DecRef(system.Ctx) - if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil { + if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil { return nil, nil, err } fsopts := filesystemOptions{ maxActiveRequests: maxActiveRequests, } - fs, err := newFUSEFilesystem(system.Ctx, 0, &fsopts, &fuseDev.vfsfd) + fs, err := newFUSEFilesystem(system.Ctx, system.VFS, &FilesystemType{}, fuseDev, 0, &fsopts) if err != nil { return nil, nil, err } - fs.VFSFilesystem().Init(vfsObj, nil, fs) - return fs.conn, &fuseDev.vfsfd, nil } diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 7ab298019..2294c490e 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -114,6 +114,51 @@ func putDentrySlice(ds *[]*dentry) { dentrySlicePool.Put(ds) } +// renameMuRUnlockAndCheckCaching calls fs.renameMu.RUnlock(), then calls +// dentry.checkCachingLocked on all dentries in *dsp with fs.renameMu locked +// for writing. +// +// dsp is a pointer-to-pointer since defer evaluates its arguments immediately, +// but dentry slices are allocated lazily, and it's much easier to say "defer +// fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() { +// fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this. +func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, dsp **[]*dentry) { + fs.renameMu.RUnlock() + if *dsp == nil { + return + } + ds := **dsp + // Only go through calling dentry.checkCachingLocked() (which requires + // re-locking renameMu) if we actually have any dentries with zero refs. + checkAny := false + for i := range ds { + if atomic.LoadInt64(&ds[i].refs) == 0 { + checkAny = true + break + } + } + if checkAny { + fs.renameMu.Lock() + for _, d := range ds { + d.checkCachingLocked(ctx) + } + fs.renameMu.Unlock() + } + putDentrySlice(*dsp) +} + +func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) { + if *ds == nil { + fs.renameMu.Unlock() + return + } + for _, d := range **ds { + d.checkCachingLocked(ctx) + } + fs.renameMu.Unlock() + putDentrySlice(*ds) +} + // stepLocked resolves rp.Component() to an existing file, starting from the // given directory. // @@ -651,41 +696,6 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b return nil } -// renameMuRUnlockAndCheckCaching calls fs.renameMu.RUnlock(), then calls -// dentry.checkCachingLocked on all dentries in *ds with fs.renameMu locked for -// writing. -// -// ds is a pointer-to-pointer since defer evaluates its arguments immediately, -// but dentry slices are allocated lazily, and it's much easier to say "defer -// fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() { -// fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this. -func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) { - fs.renameMu.RUnlock() - if *ds == nil { - return - } - if len(**ds) != 0 { - fs.renameMu.Lock() - for _, d := range **ds { - d.checkCachingLocked(ctx) - } - fs.renameMu.Unlock() - } - putDentrySlice(*ds) -} - -func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) { - if *ds == nil { - fs.renameMu.Unlock() - return - } - for _, d := range **ds { - d.checkCachingLocked(ctx) - } - fs.renameMu.Unlock() - putDentrySlice(*ds) -} - // AccessAt implements vfs.Filesystem.Impl.AccessAt. func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { var ds *[]*dentry diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 53bcc9986..75a836899 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -1352,6 +1352,11 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { } if refs > 0 { if d.cached { + // This isn't strictly necessary (fs.cachedDentries is permitted to + // contain dentries with non-zero refs, which are skipped by + // fs.evictCachedDentryLocked() upon reaching the end of the LRU), + // but since we are already holding fs.renameMu for writing we may + // as well. d.fs.cachedDentries.Remove(d) d.fs.cachedDentriesLen-- d.cached = false diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index bc07d72c0..d55bdc97f 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -78,26 +78,36 @@ func putDentrySlice(ds *[]*dentry) { } // renameMuRUnlockAndCheckDrop calls fs.renameMu.RUnlock(), then calls -// dentry.checkDropLocked on all dentries in *ds with fs.renameMu locked for +// dentry.checkDropLocked on all dentries in *dsp with fs.renameMu locked for // writing. // -// ds is a pointer-to-pointer since defer evaluates its arguments immediately, +// dsp is a pointer-to-pointer since defer evaluates its arguments immediately, // but dentry slices are allocated lazily, and it's much easier to say "defer // fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() { // fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this. -func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) { +func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, dsp **[]*dentry) { fs.renameMu.RUnlock() - if *ds == nil { + if *dsp == nil { return } - if len(**ds) != 0 { + ds := **dsp + // Only go through calling dentry.checkDropLocked() (which requires + // re-locking renameMu) if we actually have any dentries with zero refs. + checkAny := false + for i := range ds { + if atomic.LoadInt64(&ds[i].refs) == 0 { + checkAny = true + break + } + } + if checkAny { fs.renameMu.Lock() - for _, d := range **ds { + for _, d := range ds { d.checkDropLocked(ctx) } fs.renameMu.Unlock() } - putDentrySlice(*ds) + putDentrySlice(*dsp) } func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) { diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index e001d5032..c53cc0122 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -50,7 +50,7 @@ type subtasksInode struct { var _ kernfs.Inode = (*subtasksInode)(nil) -func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) kernfs.Inode { +func (fs *filesystem) newSubtasks(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) kernfs.Inode { subInode := &subtasksInode{ fs: fs, task: task, @@ -58,7 +58,7 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, cgroupControllers: cgroupControllers, } // Note: credentials are overridden by taskOwnedInode. - subInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + subInode.InodeAttrs.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) subInode.InitRefs() @@ -80,7 +80,7 @@ func (i *subtasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, if subTask.ThreadGroup() != i.task.ThreadGroup() { return nil, syserror.ENOENT } - return i.fs.newTaskInode(subTask, i.pidns, false, i.cgroupControllers) + return i.fs.newTaskInode(ctx, subTask, i.pidns, false, i.cgroupControllers) } // IterDirents implements kernfs.inodeDirectory.IterDirents. diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index dc46a09bc..fea138f93 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -47,50 +47,50 @@ type taskInode struct { var _ kernfs.Inode = (*taskInode)(nil) -func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) (kernfs.Inode, error) { +func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) (kernfs.Inode, error) { if task.ExitState() == kernel.TaskExitDead { return nil, syserror.ESRCH } contents := map[string]kernfs.Inode{ - "auxv": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &auxvData{task: task}), - "cmdline": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}), - "comm": fs.newComm(task, fs.NextIno(), 0444), - "cwd": fs.newCwdSymlink(task, fs.NextIno()), - "environ": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}), - "exe": fs.newExeSymlink(task, fs.NextIno()), - "fd": fs.newFDDirInode(task), - "fdinfo": fs.newFDInfoDirInode(task), - "gid_map": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &idMapData{task: task, gids: true}), - "io": fs.newTaskOwnedInode(task, fs.NextIno(), 0400, newIO(task, isThreadGroup)), - "maps": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mapsData{task: task}), - "mem": fs.newMemInode(task, fs.NextIno(), 0400), - "mountinfo": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountInfoData{task: task}), - "mounts": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountsData{task: task}), - "net": fs.newTaskNetDir(task), - "ns": fs.newTaskOwnedDir(task, fs.NextIno(), 0511, map[string]kernfs.Inode{ - "net": fs.newNamespaceSymlink(task, fs.NextIno(), "net"), - "pid": fs.newNamespaceSymlink(task, fs.NextIno(), "pid"), - "user": fs.newNamespaceSymlink(task, fs.NextIno(), "user"), + "auxv": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &auxvData{task: task}), + "cmdline": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}), + "comm": fs.newComm(ctx, task, fs.NextIno(), 0444), + "cwd": fs.newCwdSymlink(ctx, task, fs.NextIno()), + "environ": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}), + "exe": fs.newExeSymlink(ctx, task, fs.NextIno()), + "fd": fs.newFDDirInode(ctx, task), + "fdinfo": fs.newFDInfoDirInode(ctx, task), + "gid_map": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0644, &idMapData{task: task, gids: true}), + "io": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0400, newIO(task, isThreadGroup)), + "maps": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &mapsData{task: task}), + "mem": fs.newMemInode(ctx, task, fs.NextIno(), 0400), + "mountinfo": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &mountInfoData{task: task}), + "mounts": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &mountsData{task: task}), + "net": fs.newTaskNetDir(ctx, task), + "ns": fs.newTaskOwnedDir(ctx, task, fs.NextIno(), 0511, map[string]kernfs.Inode{ + "net": fs.newNamespaceSymlink(ctx, task, fs.NextIno(), "net"), + "pid": fs.newNamespaceSymlink(ctx, task, fs.NextIno(), "pid"), + "user": fs.newNamespaceSymlink(ctx, task, fs.NextIno(), "user"), }), - "oom_score": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, newStaticFile("0\n")), - "oom_score_adj": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &oomScoreAdj{task: task}), - "smaps": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &smapsData{task: task}), - "stat": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}), - "statm": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &statmData{task: task}), - "status": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &statusData{task: task, pidns: pidns}), - "uid_map": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}), + "oom_score": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, newStaticFile("0\n")), + "oom_score_adj": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0644, &oomScoreAdj{task: task}), + "smaps": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &smapsData{task: task}), + "stat": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}), + "statm": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &statmData{task: task}), + "status": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &statusData{task: task, pidns: pidns}), + "uid_map": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}), } if isThreadGroup { - contents["task"] = fs.newSubtasks(task, pidns, cgroupControllers) + contents["task"] = fs.newSubtasks(ctx, task, pidns, cgroupControllers) } if len(cgroupControllers) > 0 { - contents["cgroup"] = fs.newTaskOwnedInode(task, fs.NextIno(), 0444, newCgroupData(cgroupControllers)) + contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, newCgroupData(cgroupControllers)) } taskInode := &taskInode{task: task} // Note: credentials are overridden by taskOwnedInode. - taskInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + taskInode.InodeAttrs.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) taskInode.InitRefs() inode := &taskOwnedInode{Inode: taskInode, owner: task} @@ -143,17 +143,17 @@ type taskOwnedInode struct { var _ kernfs.Inode = (*taskOwnedInode)(nil) -func (fs *filesystem) newTaskOwnedInode(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) kernfs.Inode { +func (fs *filesystem) newTaskOwnedInode(ctx context.Context, task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) kernfs.Inode { // Note: credentials are overridden by taskOwnedInode. - inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm) + inode.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm) return &taskOwnedInode{Inode: inode, owner: task} } -func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]kernfs.Inode) kernfs.Inode { +func (fs *filesystem) newTaskOwnedDir(ctx context.Context, task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]kernfs.Inode) kernfs.Inode { // Note: credentials are overridden by taskOwnedInode. fdOpts := kernfs.GenericDirectoryFDOptions{SeekEnd: kernfs.SeekEndZero} - dir := kernfs.NewStaticDir(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts) + dir := kernfs.NewStaticDir(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts) return &taskOwnedInode{Inode: dir, owner: task} } diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index 3ec4471f5..02bf74dbc 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -119,7 +119,7 @@ type fdDirInode struct { var _ kernfs.Inode = (*fdDirInode)(nil) -func (fs *filesystem) newFDDirInode(task *kernel.Task) kernfs.Inode { +func (fs *filesystem) newFDDirInode(ctx context.Context, task *kernel.Task) kernfs.Inode { inode := &fdDirInode{ fdDir: fdDir{ fs: fs, @@ -127,7 +127,7 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) kernfs.Inode { produceSymlink: true, }, } - inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.InodeAttrs.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) inode.InitRefs() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) return inode @@ -148,7 +148,7 @@ func (i *fdDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err if !taskFDExists(ctx, i.task, fd) { return nil, syserror.ENOENT } - return i.fs.newFDSymlink(i.task, fd, i.fs.NextIno()), nil + return i.fs.newFDSymlink(ctx, i.task, fd, i.fs.NextIno()), nil } // Open implements kernfs.Inode.Open. @@ -204,12 +204,12 @@ type fdSymlink struct { var _ kernfs.Inode = (*fdSymlink)(nil) -func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) kernfs.Inode { +func (fs *filesystem) newFDSymlink(ctx context.Context, task *kernel.Task, fd int32, ino uint64) kernfs.Inode { inode := &fdSymlink{ task: task, fd: fd, } - inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + inode.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) return inode } @@ -257,14 +257,14 @@ type fdInfoDirInode struct { var _ kernfs.Inode = (*fdInfoDirInode)(nil) -func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) kernfs.Inode { +func (fs *filesystem) newFDInfoDirInode(ctx context.Context, task *kernel.Task) kernfs.Inode { inode := &fdInfoDirInode{ fdDir: fdDir{ fs: fs, task: task, }, } - inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.InodeAttrs.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) inode.InitRefs() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) return inode @@ -284,7 +284,7 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode, task: i.task, fd: fd, } - return i.fs.newTaskOwnedInode(i.task, i.fs.NextIno(), 0444, data), nil + return i.fs.newTaskOwnedInode(ctx, i.task, i.fs.NextIno(), 0444, data), nil } // IterDirents implements Inode.IterDirents. diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index ba71d0fde..a3780b222 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -248,9 +248,9 @@ type commInode struct { task *kernel.Task } -func (fs *filesystem) newComm(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode { +func (fs *filesystem) newComm(ctx context.Context, task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode { inode := &commInode{task: task} - inode.DynamicBytesFile.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm) + inode.DynamicBytesFile.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm) return inode } @@ -383,10 +383,10 @@ type memInode struct { locks vfs.FileLocks } -func (fs *filesystem) newMemInode(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode { +func (fs *filesystem) newMemInode(ctx context.Context, task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode { // Note: credentials are overridden by taskOwnedInode. inode := &memInode{task: task} - inode.init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm) + inode.init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm) return &taskOwnedInode{Inode: inode, owner: task} } @@ -812,9 +812,9 @@ type exeSymlink struct { var _ kernfs.Inode = (*exeSymlink)(nil) -func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) kernfs.Inode { +func (fs *filesystem) newExeSymlink(ctx context.Context, task *kernel.Task, ino uint64) kernfs.Inode { inode := &exeSymlink{task: task} - inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + inode.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) return inode } @@ -888,9 +888,9 @@ type cwdSymlink struct { var _ kernfs.Inode = (*cwdSymlink)(nil) -func (fs *filesystem) newCwdSymlink(task *kernel.Task, ino uint64) kernfs.Inode { +func (fs *filesystem) newCwdSymlink(ctx context.Context, task *kernel.Task, ino uint64) kernfs.Inode { inode := &cwdSymlink{task: task} - inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + inode.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) return inode } @@ -999,7 +999,7 @@ type namespaceSymlink struct { task *kernel.Task } -func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) kernfs.Inode { +func (fs *filesystem) newNamespaceSymlink(ctx context.Context, task *kernel.Task, ino uint64, ns string) kernfs.Inode { // Namespace symlinks should contain the namespace name and the inode number // for the namespace instance, so for example user:[123456]. We currently fake // the inode number by sticking the symlink inode in its place. @@ -1007,7 +1007,7 @@ func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns stri inode := &namespaceSymlink{task: task} // Note: credentials are overridden by taskOwnedInode. - inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target) + inode.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target) taskInode := &taskOwnedInode{Inode: inode, owner: task} return taskInode diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go index 5a9ee111f..5cf8a071a 100644 --- a/pkg/sentry/fsimpl/proc/task_net.go +++ b/pkg/sentry/fsimpl/proc/task_net.go @@ -37,7 +37,7 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -func (fs *filesystem) newTaskNetDir(task *kernel.Task) kernfs.Inode { +func (fs *filesystem) newTaskNetDir(ctx context.Context, task *kernel.Task) kernfs.Inode { k := task.Kernel() pidns := task.PIDNamespace() root := auth.NewRootCredentials(pidns.UserNamespace()) @@ -57,37 +57,37 @@ func (fs *filesystem) newTaskNetDir(task *kernel.Task) kernfs.Inode { // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task // network namespace. contents = map[string]kernfs.Inode{ - "dev": fs.newInode(task, root, 0444, &netDevData{stack: stack}), - "snmp": fs.newInode(task, root, 0444, &netSnmpData{stack: stack}), + "dev": fs.newInode(ctx, root, 0444, &netDevData{stack: stack}), + "snmp": fs.newInode(ctx, root, 0444, &netSnmpData{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, if the file contains a header the stub is just the header // otherwise it is an empty file. - "arp": fs.newInode(task, root, 0444, newStaticFile(arp)), - "netlink": fs.newInode(task, root, 0444, newStaticFile(netlink)), - "netstat": fs.newInode(task, root, 0444, &netStatData{}), - "packet": fs.newInode(task, root, 0444, newStaticFile(packet)), - "protocols": fs.newInode(task, root, 0444, newStaticFile(protocols)), + "arp": fs.newInode(ctx, root, 0444, newStaticFile(arp)), + "netlink": fs.newInode(ctx, root, 0444, newStaticFile(netlink)), + "netstat": fs.newInode(ctx, root, 0444, &netStatData{}), + "packet": fs.newInode(ctx, root, 0444, newStaticFile(packet)), + "protocols": fs.newInode(ctx, root, 0444, newStaticFile(protocols)), // Linux sets psched values to: nsec per usec, psched tick in ns, 1000000, // high res timer ticks per sec (ClockGetres returns 1ns resolution). - "psched": fs.newInode(task, root, 0444, newStaticFile(psched)), - "ptype": fs.newInode(task, root, 0444, newStaticFile(ptype)), - "route": fs.newInode(task, root, 0444, &netRouteData{stack: stack}), - "tcp": fs.newInode(task, root, 0444, &netTCPData{kernel: k}), - "udp": fs.newInode(task, root, 0444, &netUDPData{kernel: k}), - "unix": fs.newInode(task, root, 0444, &netUnixData{kernel: k}), + "psched": fs.newInode(ctx, root, 0444, newStaticFile(psched)), + "ptype": fs.newInode(ctx, root, 0444, newStaticFile(ptype)), + "route": fs.newInode(ctx, root, 0444, &netRouteData{stack: stack}), + "tcp": fs.newInode(ctx, root, 0444, &netTCPData{kernel: k}), + "udp": fs.newInode(ctx, root, 0444, &netUDPData{kernel: k}), + "unix": fs.newInode(ctx, root, 0444, &netUnixData{kernel: k}), } if stack.SupportsIPv6() { - contents["if_inet6"] = fs.newInode(task, root, 0444, &ifinet6{stack: stack}) - contents["ipv6_route"] = fs.newInode(task, root, 0444, newStaticFile("")) - contents["tcp6"] = fs.newInode(task, root, 0444, &netTCP6Data{kernel: k}) - contents["udp6"] = fs.newInode(task, root, 0444, newStaticFile(upd6)) + contents["if_inet6"] = fs.newInode(ctx, root, 0444, &ifinet6{stack: stack}) + contents["ipv6_route"] = fs.newInode(ctx, root, 0444, newStaticFile("")) + contents["tcp6"] = fs.newInode(ctx, root, 0444, &netTCP6Data{kernel: k}) + contents["udp6"] = fs.newInode(ctx, root, 0444, newStaticFile(upd6)) } } - return fs.newTaskOwnedDir(task, fs.NextIno(), 0555, contents) + return fs.newTaskOwnedDir(ctx, task, fs.NextIno(), 0555, contents) } // ifinet6 implements vfs.DynamicBytesSource for /proc/net/if_inet6. diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 151d1f10d..fdc580610 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -118,7 +118,7 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err return nil, syserror.ENOENT } - return i.fs.newTaskInode(task, i.pidns, true, i.cgroupControllers) + return i.fs.newTaskInode(ctx, task, i.pidns, true, i.cgroupControllers) } // IterDirents implements kernfs.inodeDirectory.IterDirents. diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index 738c0c9cc..205ad8192 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -136,7 +136,7 @@ func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns config := &kernel.TaskConfig{ Kernel: k, ThreadGroup: tc, - TaskContext: &kernel.TaskContext{Name: name, MemoryManager: m}, + TaskImage: &kernel.TaskImage{Name: name, MemoryManager: m}, Credentials: auth.CredentialsFromContext(ctx), NetworkNamespace: k.RootNetworkNamespace(), AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()), diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index e39cd305b..61138a7a4 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -381,6 +381,8 @@ afterTrailingSymlink: creds := rp.Credentials() child := fs.newDentry(fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)) parentDir.insertChildLocked(child, name) + child.IncRef() + defer child.DecRef(ctx) unlock() fd, err := child.open(ctx, rp, &opts, true) if err != nil { diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index 98680fde9..f8e0cffb0 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -565,7 +565,7 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er // WriteFromBlocks implements safemem.Writer.WriteFromBlocks. // -// Preconditions: inode.mu must be held. +// Preconditions: rw.file.inode.mu must be held. func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { // Hold dataMu so we can modify size. rw.file.dataMu.Lock() @@ -657,7 +657,7 @@ exitLoop: // If the write ends beyond the file's previous size, it causes the // file to grow. if rw.off > rw.file.size { - rw.file.size = rw.off + atomic.StoreUint64(&rw.file.size, rw.off) } return done, retErr diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 85a3dfe20..0c9c639d3 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -478,9 +478,9 @@ func (i *inode) statTo(stat *linux.Statx) { stat.GID = atomic.LoadUint32(&i.gid) stat.Mode = uint16(atomic.LoadUint32(&i.mode)) stat.Ino = i.ino - stat.Atime = linux.NsecToStatxTimestamp(i.atime) - stat.Ctime = linux.NsecToStatxTimestamp(i.ctime) - stat.Mtime = linux.NsecToStatxTimestamp(i.mtime) + stat.Atime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&i.atime)) + stat.Ctime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&i.ctime)) + stat.Mtime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&i.mtime)) stat.DevMajor = linux.UNNAMED_MAJOR stat.DevMinor = i.fs.devMinor switch impl := i.impl.(type) { @@ -631,7 +631,8 @@ func (i *inode) direntType() uint8 { } func (i *inode) isDir() bool { - return linux.FileMode(i.mode).FileType() == linux.S_IFDIR + mode := linux.FileMode(atomic.LoadUint32(&i.mode)) + return mode.FileType() == linux.S_IFDIR } func (i *inode) touchAtime(mnt *vfs.Mount) { diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 4e8d63d51..59fcff498 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -16,6 +16,7 @@ package verity import ( "bytes" + "encoding/json" "fmt" "io" "strconv" @@ -31,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // Sync implements vfs.FilesystemImpl.Sync. @@ -105,8 +107,10 @@ func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*de // Dentries which may have a reference count of zero, and which therefore // should be dropped once traversal is complete, are appended to ds. // -// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. -// !rp.Done(). +// Preconditions: +// * fs.renameMu must be locked. +// * d.dirMu must be locked. +// * !rp.Done(). func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) { if !d.isDir() { return nil, syserror.ENOTDIR @@ -156,15 +160,19 @@ afterSymlink: return child, nil } -// verifyChild verifies the hash of child against the already verified hash of -// the parent to ensure the child is expected. verifyChild triggers a sentry -// panic if unexpected modifications to the file system are detected. In +// verifyChildLocked verifies the hash of child against the already verified +// hash of the parent to ensure the child is expected. verifyChild triggers a +// sentry panic if unexpected modifications to the file system are detected. In // noCrashOnVerificationFailure mode it returns a syserror instead. -// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. +// +// Preconditions: +// * fs.renameMu must be locked. +// * d.dirMu must be locked. +// // TODO(b/166474175): Investigate all possible errors returned in this // function, and make sure we differentiate all errors that indicate unexpected // modifications to the file system from the ones that are not harmful. -func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) { +func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) { vfsObj := fs.vfsfs.VirtualFilesystem() // Get the path to the child dentry. This is only used to provide path @@ -266,35 +274,44 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // contain the hash of the children in the parent Merkle tree when // Verify returns with success. var buf bytes.Buffer - if _, err := merkletree.Verify(&merkletree.VerifyParams{ - Out: &buf, - File: &fdReader, - Tree: &fdReader, - Size: int64(parentSize), - Name: parent.name, - Mode: uint32(parentStat.Mode), - UID: parentStat.UID, - GID: parentStat.GID, + parent.hashMu.RLock() + _, err = merkletree.Verify(&merkletree.VerifyParams{ + Out: &buf, + File: &fdReader, + Tree: &fdReader, + Size: int64(parentSize), + Name: parent.name, + Mode: uint32(parentStat.Mode), + UID: parentStat.UID, + GID: parentStat.GID, + Children: parent.childrenNames, //TODO(b/156980949): Support passing other hash algorithms. HashAlgorithms: fs.alg.toLinuxHashAlg(), ReadOffset: int64(offset), ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())), Expected: parent.hash, DataAndTreeInSameFile: true, - }); err != nil && err != io.EOF { + }) + parent.hashMu.RUnlock() + if err != nil && err != io.EOF { return nil, alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err)) } // Cache child hash when it's verified the first time. + child.hashMu.Lock() if len(child.hash) == 0 { child.hash = buf.Bytes() } + child.hashMu.Unlock() return child, nil } -// verifyStat verifies the stat against the verified hash. The mode/uid/gid of -// the file is cached after verified. -func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Statx) error { +// verifyStatAndChildrenLocked verifies the stat and children names against the +// verified hash. The mode/uid/gid and childrenNames of the file is cached +// after verified. +// +// Preconditions: d.dirMu must be locked. +func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry, stat linux.Statx) error { vfsObj := fs.vfsfs.VirtualFilesystem() // Get the path to the child dentry. This is only used to provide path @@ -337,20 +354,65 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } + if d.isDir() && len(d.childrenNames) == 0 { + childrenOffString, err := fd.GetXattr(ctx, &vfs.GetXattrOptions{ + Name: childrenOffsetXattr, + Size: sizeOfStringInt32, + }) + + if err == syserror.ENODATA { + return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenOffsetXattr, childPath, err)) + } + if err != nil { + return err + } + childrenOffset, err := strconv.Atoi(childrenOffString) + if err != nil { + return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err)) + } + + childrenSizeString, err := fd.GetXattr(ctx, &vfs.GetXattrOptions{ + Name: childrenSizeXattr, + Size: sizeOfStringInt32, + }) + + if err == syserror.ENODATA { + return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenSizeXattr, childPath, err)) + } + if err != nil { + return err + } + childrenSize, err := strconv.Atoi(childrenSizeString) + if err != nil { + return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err)) + } + + childrenNames := make([]byte, childrenSize) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(childrenNames), int64(childrenOffset), vfs.ReadOptions{}); err != nil { + return alertIntegrityViolation(fmt.Sprintf("Failed to read children map for %s: %v", childPath, err)) + } + + if err := json.Unmarshal(childrenNames, &d.childrenNames); err != nil { + return alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames of %s: %v", childPath, err)) + } + } + fdReader := vfs.FileReadWriteSeeker{ FD: fd, Ctx: ctx, } var buf bytes.Buffer + d.hashMu.RLock() params := &merkletree.VerifyParams{ - Out: &buf, - Tree: &fdReader, - Size: int64(size), - Name: d.name, - Mode: uint32(stat.Mode), - UID: stat.UID, - GID: stat.GID, + Out: &buf, + Tree: &fdReader, + Size: int64(size), + Name: d.name, + Mode: uint32(stat.Mode), + UID: stat.UID, + GID: stat.GID, + Children: d.childrenNames, //TODO(b/156980949): Support passing other hash algorithms. HashAlgorithms: fs.alg.toLinuxHashAlg(), ReadOffset: 0, @@ -359,6 +421,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat Expected: d.hash, DataAndTreeInSameFile: false, } + d.hashMu.RUnlock() if atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFDIR { params.DataAndTreeInSameFile = true } @@ -373,7 +436,9 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat return nil } -// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. +// Preconditions: +// * fs.renameMu must be locked. +// * parent.dirMu must be locked. func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { if child, ok := parent.children[name]; ok { // If verity is enabled on child, we should check again whether @@ -422,7 +487,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s // be cached before enabled. if fs.allowRuntimeEnable { if parent.verityEnabled() { - if _, err := fs.verifyChild(ctx, parent, child); err != nil { + if _, err := fs.verifyChildLocked(ctx, parent, child); err != nil { return nil, err } } @@ -438,7 +503,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s if err != nil { return nil, err } - if err := fs.verifyStat(ctx, child, stat); err != nil { + if err := fs.verifyStatAndChildrenLocked(ctx, child, stat); err != nil { return nil, err } } @@ -458,96 +523,64 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s return child, nil } -// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked. +// Preconditions: +// * fs.renameMu must be locked. +// * parent.dirMu must be locked. func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) { vfsObj := fs.vfsfs.VirtualFilesystem() - childVD, childErr := parent.getLowerAt(ctx, vfsObj, name) - // We will handle ENOENT separately, as it may indicate unexpected - // modifications to the file system, and may cause a sentry panic. - if childErr != nil && childErr != syserror.ENOENT { - return nil, childErr + if parent.verityEnabled() { + if _, ok := parent.childrenNames[name]; !ok { + return nil, syserror.ENOENT + } } - // The dentry needs to be cleaned up if any error occurs. IncRef will be - // called if a verity child dentry is successfully created. - if childErr == nil { - defer childVD.DecRef(ctx) + parentPath, err := vfsObj.PathnameWithDeleted(ctx, parent.fs.rootDentry.lowerVD, parent.lowerVD) + if err != nil { + return nil, err } - childMerkleVD, childMerkleErr := parent.getLowerAt(ctx, vfsObj, merklePrefix+name) - // We will handle ENOENT separately, as it may indicate unexpected - // modifications to the file system, and may cause a sentry panic. - if childMerkleErr != nil && childMerkleErr != syserror.ENOENT { - return nil, childMerkleErr + childVD, err := parent.getLowerAt(ctx, vfsObj, name) + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(fmt.Sprintf("file %s expected but not found", parentPath+"/"+name)) + } + if err != nil { + return nil, err } // The dentry needs to be cleaned up if any error occurs. IncRef will be // called if a verity child dentry is successfully created. - if childMerkleErr == nil { - defer childMerkleVD.DecRef(ctx) - } + defer childVD.DecRef(ctx) - // Get the path to the parent dentry. This is only used to provide path - // information in failure case. - parentPath, err := vfsObj.PathnameWithDeleted(ctx, parent.fs.rootDentry.lowerVD, parent.lowerVD) - if err != nil { + childMerkleVD, err := parent.getLowerAt(ctx, vfsObj, merklePrefix+name) + if err == syserror.ENOENT { + if !fs.allowRuntimeEnable { + return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath+"/"+name)) + } + childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ + Root: parent.lowerVD, + Start: parent.lowerVD, + Path: fspath.Parse(merklePrefix + name), + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR | linux.O_CREAT, + Mode: 0644, + }) + if err != nil { + return nil, err + } + childMerkleFD.DecRef(ctx) + childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name) + if err != nil { + return nil, err + } + } + if err != nil && err != syserror.ENOENT { return nil, err } - // TODO(b/166474175): Investigate all possible errors of childErr and - // childMerkleErr, and make sure we differentiate all errors that - // indicate unexpected modifications to the file system from the ones - // that are not harmful. - if childErr == syserror.ENOENT && childMerkleErr == nil { - // Failed to get child file/directory dentry. However the - // corresponding Merkle tree is found. This indicates an - // unexpected modification to the file system that - // removed/renamed the child. - return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name)) - } else if childErr == nil && childMerkleErr == syserror.ENOENT { - // If in allowRuntimeEnable mode, and the Merkle tree file is - // not created yet, we create an empty Merkle tree file, so that - // if the file is enabled through ioctl, we have the Merkle tree - // file open and ready to use. - // This may cause empty and unused Merkle tree files in - // allowRuntimeEnable mode, if they are never enabled. This - // does not affect verification, as we rely on cached hash to - // decide whether to perform verification, not the existence of - // the Merkle tree file. Also, those Merkle tree files are - // always hidden and cannot be accessed by verity fs users. - if fs.allowRuntimeEnable { - childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: fspath.Parse(merklePrefix + name), - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR | linux.O_CREAT, - Mode: 0644, - }) - if err != nil { - return nil, err - } - childMerkleFD.DecRef(ctx) - childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name) - if err != nil { - return nil, err - } - } else { - // If runtime enable is not allowed. This indicates an - // unexpected modification to the file system that - // removed/renamed the Merkle tree file. - return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name)) - } - } else if childErr == syserror.ENOENT && childMerkleErr == syserror.ENOENT { - // Both the child and the corresponding Merkle tree are missing. - // This could be an unexpected modification or due to incorrect - // parameter. - // TODO(b/167752508): Investigate possible ways to differentiate - // cases that both files are deleted from cases that they never - // exist in the file system. - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to find file %s", parentPath+"/"+name)) - } + // The dentry needs to be cleaned up if any error occurs. IncRef will be + // called if a verity child dentry is successfully created. + defer childMerkleVD.DecRef(ctx) mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID) stat, err := vfsObj.StatAt(ctx, fs.creds, &vfs.PathOperation{ @@ -577,18 +610,19 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, child.mode = uint32(stat.Mode) child.uid = stat.UID child.gid = stat.GID + child.childrenNames = make(map[string]struct{}) // Verify child hash. This should always be performed unless in // allowRuntimeEnable mode and the parent directory hasn't been enabled // yet. if parent.verityEnabled() { - if _, err := fs.verifyChild(ctx, parent, child); err != nil { + if _, err := fs.verifyChildLocked(ctx, parent, child); err != nil { child.destroyLocked(ctx) return nil, err } } if child.verityEnabled() { - if err := fs.verifyStat(ctx, child, stat); err != nil { + if err := fs.verifyStatAndChildrenLocked(ctx, child, stat); err != nil { child.destroyLocked(ctx) return nil, err } @@ -602,7 +636,9 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // rp.Start().Impl().(*dentry)). It does not check that the returned directory // is searchable by the provider of rp. // -// Preconditions: fs.renameMu must be locked. !rp.Done(). +// Preconditions: +// * fs.renameMu must be locked. +// * !rp.Done(). func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) { for !rp.Final() { d.dirMu.Lock() @@ -943,11 +979,13 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if err != nil { return linux.Statx{}, err } + d.dirMu.Lock() if d.verityEnabled() { - if err := fs.verifyStat(ctx, d, stat); err != nil { + if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil { return linux.Statx{}, err } } + d.dirMu.Unlock() return stat, nil } diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index faa862c55..46346e54d 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -19,9 +19,22 @@ // The verity file system is read-only, except for one case: when // allowRuntimeEnable is true, additional Merkle files can be generated using // the FS_IOC_ENABLE_VERITY ioctl. +// +// Lock order: +// +// filesystem.renameMu +// dentry.dirMu +// fileDescription.mu +// filesystem.verityMu +// dentry.hashMu +// +// Locking dentry.dirMu in multiple dentries requires that parent dentries are +// locked before child dentries, and that filesystem.renameMu is locked to +// stabilize this relationship. package verity import ( + "encoding/json" "fmt" "math" "strconv" @@ -60,6 +73,15 @@ const ( // file size. For a directory, this is the size of all its children's hashes. merkleSizeXattr = "user.merkle.size" + // childrenOffsetXattr is the extended attribute name specifying the + // names of the offset of the serialized children names in the Merkle + // tree file. + childrenOffsetXattr = "user.merkle.childrenOffset" + + // childrenSizeXattr is the extended attribute name specifying the size + // of the serialized children names. + childrenSizeXattr = "user.merkle.childrenSize" + // sizeOfStringInt32 is the size for a 32 bit integer stored as string in // extended attributes. The maximum value of a 32 bit integer has 10 digits. sizeOfStringInt32 = 10 @@ -299,14 +321,77 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt d.uid = stat.UID d.gid = stat.GID d.hash = make([]byte, len(iopts.RootHash)) + d.childrenNames = make(map[string]struct{}) if !fs.allowRuntimeEnable { - if err := fs.verifyStat(ctx, d, stat); err != nil { + // Get children names from the underlying file system. + offString, err := vfsObj.GetXattrAt(ctx, creds, &vfs.PathOperation{ + Root: lowerMerkleVD, + Start: lowerMerkleVD, + }, &vfs.GetXattrOptions{ + Name: childrenOffsetXattr, + Size: sizeOfStringInt32, + }) + if err == syserror.ENOENT || err == syserror.ENODATA { + return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenOffsetXattr, err)) + } + if err != nil { + return nil, nil, err + } + + off, err := strconv.Atoi(offString) + if err != nil { + return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err)) + } + + sizeString, err := vfsObj.GetXattrAt(ctx, creds, &vfs.PathOperation{ + Root: lowerMerkleVD, + Start: lowerMerkleVD, + }, &vfs.GetXattrOptions{ + Name: childrenSizeXattr, + Size: sizeOfStringInt32, + }) + if err == syserror.ENOENT || err == syserror.ENODATA { + return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenSizeXattr, err)) + } + if err != nil { + return nil, nil, err + } + size, err := strconv.Atoi(sizeString) + if err != nil { + return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err)) + } + + lowerMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ + Root: lowerMerkleVD, + Start: lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }) + if err == syserror.ENOENT { + return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to open root Merkle file: %v", err)) + } + if err != nil { + return nil, nil, err + } + + childrenNames := make([]byte, size) + if _, err := lowerMerkleFD.PRead(ctx, usermem.BytesIOSequence(childrenNames), int64(off), vfs.ReadOptions{}); err != nil { + return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to read root children map: %v", err)) + } + + if err := json.Unmarshal(childrenNames, &d.childrenNames); err != nil { + return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames: %v", err)) + } + + if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil { return nil, nil, err } } + d.hashMu.Lock() copy(d.hash, iopts.RootHash) + d.hashMu.Unlock() d.vfsd.Init(d) fs.rootDentry = d @@ -331,7 +416,8 @@ type dentry struct { fs *filesystem // mode, uid, gid and size are the file mode, owner, group, and size of - // the file in the underlying file system. + // the file in the underlying file system. They are set when a dentry + // is initialized, and never modified. mode uint32 uid uint32 gid uint32 @@ -352,15 +438,24 @@ type dentry struct { dirMu sync.Mutex `state:"nosave"` children map[string]*dentry - // lowerVD is the VirtualDentry in the underlying file system. + // childrenNames stores the name of all children of the dentry. This is + // used by verity to check whether a child is expected. This is only + // populated by enableVerity. childrenNames is also protected by dirMu. + childrenNames map[string]struct{} + + // lowerVD is the VirtualDentry in the underlying file system. It is + // never modified after initialized. lowerVD vfs.VirtualDentry // lowerMerkleVD is the VirtualDentry of the corresponding Merkle tree - // in the underlying file system. + // in the underlying file system. It is never modified after + // initialized. lowerMerkleVD vfs.VirtualDentry - // hash is the calculated hash for the current file or directory. - hash []byte + // hash is the calculated hash for the current file or directory. hash + // is protected by hashMu. + hashMu sync.RWMutex `state:"nosave"` + hash []byte } // newDentry creates a new dentry representing the given verity file. The @@ -443,7 +538,9 @@ func (d *dentry) checkDropLocked(ctx context.Context) { // destroyLocked destroys the dentry. // -// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0. +// Preconditions: +// * d.fs.renameMu must be locked for writing. +// * d.refs == 0. func (d *dentry) destroyLocked(ctx context.Context) { switch atomic.LoadInt64(&d.refs) { case 0: @@ -523,6 +620,8 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) // mode, it returns true if the target has been enabled with // ioctl(FS_IOC_ENABLE_VERITY). func (d *dentry) verityEnabled() bool { + d.hashMu.RLock() + defer d.hashMu.RUnlock() return !d.fs.allowRuntimeEnable || len(d.hash) != 0 } @@ -602,11 +701,13 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu if err != nil { return linux.Statx{}, err } + fd.d.dirMu.Lock() if fd.d.verityEnabled() { - if err := fd.d.fs.verifyStat(ctx, fd.d, stat); err != nil { + if err := fd.d.fs.verifyStatAndChildrenLocked(ctx, fd.d, stat); err != nil { return linux.Statx{}, err } } + fd.d.dirMu.Unlock() return stat, nil } @@ -642,13 +743,15 @@ func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) return offset, nil } -// generateMerkle generates a Merkle tree file for fd. If fd points to a file -// /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The hash -// of the generated Merkle tree and the data size is returned. If fd points to -// a regular file, the data is the content of the file. If fd points to a -// directory, the data is all hahes of its children, written to the Merkle tree -// file. -func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, error) { +// generateMerkleLocked generates a Merkle tree file for fd. If fd points to a +// file /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The +// hash of the generated Merkle tree and the data size is returned. If fd +// points to a regular file, the data is the content of the file. If fd points +// to a directory, the data is all hahes of its children, written to the Merkle +// tree file. +// +// Preconditions: fd.d.fs.verityMu must be locked. +func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, uint64, error) { fdReader := vfs.FileReadWriteSeeker{ FD: fd.lowerFD, Ctx: ctx, @@ -664,6 +767,7 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, params := &merkletree.GenerateParams{ TreeReader: &merkleReader, TreeWriter: &merkleWriter, + Children: fd.d.childrenNames, //TODO(b/156980949): Support passing other hash algorithms. HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), } @@ -716,9 +820,48 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, return hash, uint64(params.Size), err } +// recordChildrenLocked writes the names of fd's children into the +// corresponding Merkle tree file, and saves the offset/size of the map into +// xattrs. +// +// Preconditions: +// * fd.d.fs.verityMu must be locked. +// * fd.d.isDir() == true. +func (fd *fileDescription) recordChildrenLocked(ctx context.Context) error { + // Record the children names in the Merkle tree file. + childrenNames, err := json.Marshal(fd.d.childrenNames) + if err != nil { + return err + } + + stat, err := fd.merkleWriter.Stat(ctx, vfs.StatOptions{}) + if err != nil { + return err + } + + if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{ + Name: childrenOffsetXattr, + Value: strconv.Itoa(int(stat.Size)), + }); err != nil { + return err + } + if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{ + Name: childrenSizeXattr, + Value: strconv.Itoa(len(childrenNames)), + }); err != nil { + return err + } + + if _, err = fd.merkleWriter.Write(ctx, usermem.BytesIOSequence(childrenNames), vfs.WriteOptions{}); err != nil { + return err + } + + return nil +} + // enableVerity enables verity features on fd by generating a Merkle tree file // and stores its hash in its parent directory's Merkle tree. -func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (uintptr, error) { +func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) { if !fd.d.fs.allowRuntimeEnable { return 0, syserror.EPERM } @@ -734,7 +877,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui return 0, alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds") } - hash, dataSize, err := fd.generateMerkle(ctx) + hash, dataSize, err := fd.generateMerkleLocked(ctx) if err != nil { return 0, err } @@ -761,6 +904,9 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui }); err != nil { return 0, err } + + // Add the current child's name to parent's childrenNames. + fd.d.parent.childrenNames[fd.d.name] = struct{}{} } // Record the size of the data being hashed for fd. @@ -770,18 +916,29 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui }); err != nil { return 0, err } - fd.d.hash = append(fd.d.hash, hash...) + + if fd.d.isDir() { + if err := fd.recordChildrenLocked(ctx); err != nil { + return 0, err + } + } + fd.d.hashMu.Lock() + fd.d.hash = hash + fd.d.hashMu.Unlock() return 0, nil } // measureVerity returns the hash of fd, saved in verityDigest. -func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, verityDigest usermem.Addr) (uintptr, error) { +func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest usermem.Addr) (uintptr, error) { t := kernel.TaskFromContext(ctx) if t == nil { return 0, syserror.EINVAL } var metadata linux.DigestMetadata + fd.d.hashMu.RLock() + defer fd.d.hashMu.RUnlock() + // If allowRuntimeEnable is true, an empty fd.d.hash indicates that // verity is not enabled for the file. If allowRuntimeEnable is false, // this is an integrity violation because all files should have verity @@ -815,14 +972,16 @@ func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, ve return 0, err } -func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flags usermem.Addr) (uintptr, error) { +func (fd *fileDescription) verityFlags(ctx context.Context, flags usermem.Addr) (uintptr, error) { f := int32(0) + fd.d.hashMu.RLock() // All enabled files should store a hash. This flag is not settable via // FS_IOC_SETFLAGS. if len(fd.d.hash) != 0 { f |= linux.FS_VERITY_FL } + fd.d.hashMu.RUnlock() t := kernel.TaskFromContext(ctx) if t == nil { @@ -836,11 +995,11 @@ func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flag func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { switch cmd := args[1].Uint(); cmd { case linux.FS_IOC_ENABLE_VERITY: - return fd.enableVerity(ctx, uio) + return fd.enableVerity(ctx) case linux.FS_IOC_MEASURE_VERITY: - return fd.measureVerity(ctx, uio, args[2].Pointer()) + return fd.measureVerity(ctx, args[2].Pointer()) case linux.FS_IOC_GETFLAGS: - return fd.verityFlags(ctx, uio, args[2].Pointer()) + return fd.verityFlags(ctx, args[2].Pointer()) default: // TODO(b/169682228): Investigate which ioctl commands should // be allowed. @@ -901,15 +1060,17 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of Ctx: ctx, } + fd.d.hashMu.RLock() n, err := merkletree.Verify(&merkletree.VerifyParams{ - Out: dst.Writer(ctx), - File: &dataReader, - Tree: &merkleReader, - Size: int64(size), - Name: fd.d.name, - Mode: fd.d.mode, - UID: fd.d.uid, - GID: fd.d.gid, + Out: dst.Writer(ctx), + File: &dataReader, + Tree: &merkleReader, + Size: int64(size), + Name: fd.d.name, + Mode: fd.d.mode, + UID: fd.d.uid, + GID: fd.d.gid, + Children: fd.d.childrenNames, //TODO(b/156980949): Support passing other hash algorithms. HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), ReadOffset: offset, @@ -917,6 +1078,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of Expected: fd.d.hash, DataAndTreeInSameFile: false, }) + fd.d.hashMu.RUnlock() if err != nil { return 0, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) } diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index b2da9dd96..7196e74eb 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -18,6 +18,7 @@ import ( "fmt" "io" "math/rand" + "strconv" "testing" "time" @@ -307,6 +308,56 @@ func TestReopenUnmodifiedFileSucceeds(t *testing.T) { } } +// TestOpenNonexistentFile ensures that opening a nonexistent file does not +// trigger verification failure, even if the parent directory is verified. +func TestOpenNonexistentFile(t *testing.T) { + vfsObj, root, ctx, err := newVerityRoot(t, SHA256) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirms a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Enable verity on the parent directory. + parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Ensure open an unexpected file in the parent directory fails with + // ENOENT rather than verification failure. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename + "abc"), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err != syserror.ENOENT { + t.Errorf("OpenAt unexpected error: %v", err) + } +} + // TestPReadModifiedFileFails ensures that read from a modified verity file // fails. func TestPReadModifiedFileFails(t *testing.T) { @@ -439,10 +490,10 @@ func TestModifiedMerkleFails(t *testing.T) { // Flip a random bit in the Merkle tree file. stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{}) if err != nil { - t.Fatalf("stat: %v", err) + t.Errorf("lowerMerkleFD.Stat: %v", err) } - merkleSize := int(stat.Size) - if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil { + + if err := corruptRandomBit(ctx, lowerMerkleFD, int(stat.Size)); err != nil { t.Fatalf("corruptRandomBit: %v", err) } @@ -510,11 +561,17 @@ func TestModifiedParentMerkleFails(t *testing.T) { // This parent directory contains only one child, so any random // modification in the parent Merkle tree should cause verification // failure when opening the child file. - stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{}) + sizeString, err := parentLowerMerkleFD.GetXattr(ctx, &vfs.GetXattrOptions{ + Name: childrenOffsetXattr, + Size: sizeOfStringInt32, + }) + if err != nil { + t.Fatalf("parentLowerMerkleFD.GetXattr: %v", err) + } + parentMerkleSize, err := strconv.Atoi(sizeString) if err != nil { - t.Fatalf("stat: %v", err) + t.Fatalf("Failed convert size to int: %v", err) } - parentMerkleSize := int(stat.Size) if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { t.Fatalf("corruptRandomBit: %v", err) } @@ -602,124 +659,165 @@ func TestModifiedStatFails(t *testing.T) { } } -// TestOpenDeletedOrRenamedFileFails ensures that opening a deleted/renamed -// verity enabled file or the corresponding Merkle tree file fails with the -// verify error. +// TestOpenDeletedFileFails ensures that opening a deleted verity enabled file +// and/or the corresponding Merkle tree file fails with the verity error. func TestOpenDeletedFileFails(t *testing.T) { testCases := []struct { - // Tests removing files is remove is true. Otherwise tests - // renaming files. - remove bool - // The original file is removed/renamed if changeFile is true. + // The original file is removed if changeFile is true. changeFile bool - // The Merkle tree file is removed/renamed if changeMerkleFile - // is true. + // The Merkle tree file is removed if changeMerkleFile is true. changeMerkleFile bool }{ { - remove: true, changeFile: true, changeMerkleFile: false, }, { - remove: true, changeFile: false, changeMerkleFile: true, }, { - remove: false, changeFile: true, - changeMerkleFile: false, + changeMerkleFile: true, }, + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("changeFile:%t, changeMerkleFile:%t", tc.changeFile, tc.changeMerkleFile), func(t *testing.T) { + vfsObj, root, ctx, err := newVerityRoot(t, SHA256) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD + if tc.changeFile { + if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(filename), + }); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } + } + if tc.changeMerkleFile { + if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(merklePrefix + filename), + }); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } + } + + // Ensure reopening the verity enabled file fails. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err != syserror.EIO { + t.Errorf("got OpenAt error: %v, expected EIO", err) + } + }) + } +} + +// TestOpenRenamedFileFails ensures that opening a renamed verity enabled file +// and/or the corresponding Merkle tree file fails with the verity error. +func TestOpenRenamedFileFails(t *testing.T) { + testCases := []struct { + // The original file is renamed if changeFile is true. + changeFile bool + // The Merkle tree file is renamed if changeMerkleFile is true. + changeMerkleFile bool + }{ { - remove: false, changeFile: true, changeMerkleFile: false, }, + { + changeFile: false, + changeMerkleFile: true, + }, + { + changeFile: true, + changeMerkleFile: true, + }, } for _, tc := range testCases { - t.Run(fmt.Sprintf("remove:%t", tc.remove), func(t *testing.T) { - for _, alg := range hashAlgs { - vfsObj, root, ctx, err := newVerityRoot(t, alg) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } + t.Run(fmt.Sprintf("changeFile:%t, changeMerkleFile:%t", tc.changeFile, tc.changeMerkleFile), func(t *testing.T) { + vfsObj, root, ctx, err := newVerityRoot(t, SHA256) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } - rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD - if tc.remove { - if tc.changeFile { - if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: rootLowerVD, - Start: rootLowerVD, - Path: fspath.Parse(filename), - }); err != nil { - t.Fatalf("UnlinkAt: %v", err) - } - } - if tc.changeMerkleFile { - if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: rootLowerVD, - Start: rootLowerVD, - Path: fspath.Parse(merklePrefix + filename), - }); err != nil { - t.Fatalf("UnlinkAt: %v", err) - } - } - } else { - newFilename := "renamed-test-file" - if tc.changeFile { - if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: rootLowerVD, - Start: rootLowerVD, - Path: fspath.Parse(filename), - }, &vfs.PathOperation{ - Root: rootLowerVD, - Start: rootLowerVD, - Path: fspath.Parse(newFilename), - }, &vfs.RenameOptions{}); err != nil { - t.Fatalf("RenameAt: %v", err) - } - } - if tc.changeMerkleFile { - if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: rootLowerVD, - Start: rootLowerVD, - Path: fspath.Parse(merklePrefix + filename), - }, &vfs.PathOperation{ - Root: rootLowerVD, - Start: rootLowerVD, - Path: fspath.Parse(merklePrefix + newFilename), - }, &vfs.RenameOptions{}); err != nil { - t.Fatalf("UnlinkAt: %v", err) - } - } - } + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } - // Ensure reopening the verity enabled file fails. - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, + rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD + newFilename := "renamed-test-file" + if tc.changeFile { + if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, Path: fspath.Parse(filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - Mode: linux.ModeRegular, - }); err != syserror.EIO { - t.Errorf("got OpenAt error: %v, expected EIO", err) + }, &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(newFilename), + }, &vfs.RenameOptions{}); err != nil { + t.Fatalf("RenameAt: %v", err) } } + if tc.changeMerkleFile { + if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(merklePrefix + filename), + }, &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(merklePrefix + newFilename), + }, &vfs.RenameOptions{}); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } + } + + // Ensure reopening the verity enabled file fails. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err != syserror.EIO { + t.Errorf("got OpenAt error: %v, expected EIO", err) + } }) } } diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 90dd4a047..0ee60569c 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -184,6 +184,7 @@ go_library( "task_exit.go", "task_futex.go", "task_identity.go", + "task_image.go", "task_list.go", "task_log.go", "task_net.go", @@ -224,6 +225,7 @@ go_library( "//pkg/cpuid", "//pkg/eventchannel", "//pkg/fspath", + "//pkg/goid", "//pkg/log", "//pkg/marshal", "//pkg/marshal/primitive", diff --git a/pkg/sentry/kernel/aio.go b/pkg/sentry/kernel/aio.go index 0ac78c0b8..ec36d1a49 100644 --- a/pkg/sentry/kernel/aio.go +++ b/pkg/sentry/kernel/aio.go @@ -15,10 +15,7 @@ package kernel import ( - "time" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/log" ) // AIOCallback is an function that does asynchronous I/O on behalf of a task. @@ -26,7 +23,7 @@ type AIOCallback func(context.Context) // QueueAIO queues an AIOCallback which will be run asynchronously. func (t *Task) QueueAIO(cb AIOCallback) { - ctx := taskAsyncContext{t: t} + ctx := t.AsyncContext() wg := &t.TaskSet().aioGoroutines wg.Add(1) go func() { @@ -34,48 +31,3 @@ func (t *Task) QueueAIO(cb AIOCallback) { wg.Done() }() } - -type taskAsyncContext struct { - context.NoopSleeper - t *Task -} - -// Debugf implements log.Logger.Debugf. -func (ctx taskAsyncContext) Debugf(format string, v ...interface{}) { - ctx.t.Debugf(format, v...) -} - -// Infof implements log.Logger.Infof. -func (ctx taskAsyncContext) Infof(format string, v ...interface{}) { - ctx.t.Infof(format, v...) -} - -// Warningf implements log.Logger.Warningf. -func (ctx taskAsyncContext) Warningf(format string, v ...interface{}) { - ctx.t.Warningf(format, v...) -} - -// IsLogging implements log.Logger.IsLogging. -func (ctx taskAsyncContext) IsLogging(level log.Level) bool { - return ctx.t.IsLogging(level) -} - -// Deadline implements context.Context.Deadline. -func (ctx taskAsyncContext) Deadline() (time.Time, bool) { - return ctx.t.Deadline() -} - -// Done implements context.Context.Done. -func (ctx taskAsyncContext) Done() <-chan struct{} { - return ctx.t.Done() -} - -// Err implements context.Context.Err. -func (ctx taskAsyncContext) Err() error { - return ctx.t.Err() -} - -// Value implements context.Context.Value. -func (ctx taskAsyncContext) Value(key interface{}) interface{} { - return ctx.t.Value(key) -} diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go index bb94769c4..a8596410f 100644 --- a/pkg/sentry/kernel/context.go +++ b/pkg/sentry/kernel/context.go @@ -15,8 +15,6 @@ package kernel import ( - "time" - "gvisor.dev/gvisor/pkg/context" ) @@ -98,18 +96,3 @@ func TaskFromContext(ctx context.Context) *Task { } return nil } - -// Deadline implements context.Context.Deadline. -func (*Task) Deadline() (time.Time, bool) { - return time.Time{}, false -} - -// Done implements context.Context.Done. -func (*Task) Done() <-chan struct{} { - return nil -} - -// Err implements context.Context.Err. -func (*Task) Err() error { - return nil -} diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 9b2be44d4..2cdcdfc1f 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -632,7 +632,7 @@ func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { defer k.tasks.mu.RUnlock() for t := range k.tasks.Root.tids { // We can skip locking Task.mu here since the kernel is paused. - if mm := t.tc.MemoryManager; mm != nil { + if mm := t.image.MemoryManager; mm != nil { if _, ok := invalidated[mm]; !ok { if err := mm.InvalidateUnsavable(ctx); err != nil { return err @@ -642,7 +642,7 @@ func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { } // I really wish we just had a sync.Map of all MMs... if r, ok := t.runState.(*runSyscallAfterExecStop); ok { - if err := r.tc.MemoryManager.InvalidateUnsavable(ctx); err != nil { + if err := r.image.MemoryManager.InvalidateUnsavable(ctx); err != nil { return err } } @@ -1017,7 +1017,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, Features: k.featureSet, } - tc, se := k.LoadTaskImage(ctx, loadArgs) + image, se := k.LoadTaskImage(ctx, loadArgs) if se != nil { return nil, 0, errors.New(se.String()) } @@ -1030,7 +1030,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, config := &TaskConfig{ Kernel: k, ThreadGroup: tg, - TaskContext: tc, + TaskImage: image, FSContext: fsContext, FDTable: args.FDTable, Credentials: args.Credentials, @@ -1046,7 +1046,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, if err != nil { return nil, 0, err } - t.traceExecEvent(tc) // Simulate exec for tracing. + t.traceExecEvent(image) // Simulate exec for tracing. // Success. cu.Release() @@ -1359,6 +1359,13 @@ func (k *Kernel) SendContainerSignal(cid string, info *arch.SignalInfo) error { // not have meaningful trace data. Rebuilding here ensures that we can do so // after tracing has been enabled. func (k *Kernel) RebuildTraceContexts() { + // We need to pause all task goroutines because Task.rebuildTraceContext() + // replaces Task.traceContext and Task.traceTask, which are + // task-goroutine-exclusive (i.e. the task goroutine assumes that it can + // access them without synchronization) for performance. + k.Pause() + defer k.Unpause() + k.extMu.Lock() defer k.extMu.Unlock() k.tasks.mu.RLock() diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go index 387edfa91..60917e7d3 100644 --- a/pkg/sentry/kernel/seccomp.go +++ b/pkg/sentry/kernel/seccomp.go @@ -106,7 +106,7 @@ func (t *Task) checkSeccompSyscall(sysno int32, args arch.SyscallArguments, ip u func (t *Task) evaluateSyscallFilters(sysno int32, args arch.SyscallArguments, ip usermem.Addr) uint32 { data := linux.SeccompData{ Nr: sysno, - Arch: t.tc.st.AuditNumber, + Arch: t.image.st.AuditNumber, InstructionPointer: uint64(ip), } // data.args is []uint64 and args is []arch.SyscallArgument (uintptr), so diff --git a/pkg/sentry/kernel/syscalls_state.go b/pkg/sentry/kernel/syscalls_state.go index 90f890495..0b17a562e 100644 --- a/pkg/sentry/kernel/syscalls_state.go +++ b/pkg/sentry/kernel/syscalls_state.go @@ -30,18 +30,18 @@ type syscallTableInfo struct { } // saveSt saves the SyscallTable. -func (tc *TaskContext) saveSt() syscallTableInfo { +func (image *TaskImage) saveSt() syscallTableInfo { return syscallTableInfo{ - OS: tc.st.OS, - Arch: tc.st.Arch, + OS: image.st.OS, + Arch: image.st.Arch, } } // loadSt loads the SyscallTable. -func (tc *TaskContext) loadSt(sti syscallTableInfo) { +func (image *TaskImage) loadSt(sti syscallTableInfo) { st, ok := LookupSyscallTable(sti.OS, sti.Arch) if !ok { panic(fmt.Sprintf("syscall table not found for OS %v, Arch %v", sti.OS, sti.Arch)) } - tc.st = st // Save the table reference. + image.st = st // Save the table reference. } diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go index a83ce219c..3fee7aa68 100644 --- a/pkg/sentry/kernel/syslog.go +++ b/pkg/sentry/kernel/syslog.go @@ -75,6 +75,12 @@ func (s *syslog) Log() []byte { "Checking naughty and nice process list...", // Check it up to twice. "Granting licence to kill(2)...", // British spelling for British movie. "Letting the watchdogs out...", + "Conjuring /dev/null black hole...", + "Adversarially training Redcode AI...", + "Singleplexing /dev/ptmx...", + "Recruiting cron-ies...", + "Verifying that no non-zero bytes made their way into /dev/zero...", + "Accelerating teletypewriter to 9600 baud...", } selectMessage := func() string { diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index 037971393..c0ab53c94 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" - "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/inet" @@ -29,11 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/kernel/sched" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/unimpl" - "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -63,6 +58,12 @@ import ( type Task struct { taskNode + // goid is the task goroutine's ID. goid is owned by the task goroutine, + // but since it's used to detect cases where non-task goroutines + // incorrectly access state owned by, or exclusive to, the task goroutine, + // goid is always accessed using atomic memory operations. + goid int64 `state:"nosave"` + // runState is what the task goroutine is executing if it is not stopped. // If runState is nil, the task goroutine should exit or has exited. // runState is exclusive to the task goroutine. @@ -83,7 +84,7 @@ type Task struct { // taskWork is exclusive to the task goroutine. taskWork []TaskWorker - // haveSyscallReturn is true if tc.Arch().Return() represents a value + // haveSyscallReturn is true if image.Arch().Return() represents a value // returned by a syscall (or set by ptrace after a syscall). // // haveSyscallReturn is exclusive to the task goroutine. @@ -257,10 +258,10 @@ type Task struct { // mu protects some of the following fields. mu sync.Mutex `state:"nosave"` - // tc holds task data provided by the ELF loader. + // image holds task data provided by the ELF loader. // - // tc is protected by mu, and is owned by the task goroutine. - tc TaskContext + // image is protected by mu, and is owned by the task goroutine. + image TaskImage // fsContext is the task's filesystem context. // @@ -274,7 +275,7 @@ type Task struct { // If vforkParent is not nil, it is the task that created this task with // vfork() or clone(CLONE_VFORK), and should have its vforkStop ended when - // this TaskContext is released. + // this TaskImage is released. // // vforkParent is protected by the TaskSet mutex. vforkParent *Task @@ -641,64 +642,6 @@ func (t *Task) Kernel() *Kernel { return t.k } -// Value implements context.Context.Value. -// -// Preconditions: The caller must be running on the task goroutine (as implied -// by the requirements of context.Context). -func (t *Task) Value(key interface{}) interface{} { - switch key { - case CtxCanTrace: - return t.CanTrace - case CtxKernel: - return t.k - case CtxPIDNamespace: - return t.tg.pidns - case CtxUTSNamespace: - return t.utsns - case CtxIPCNamespace: - ipcns := t.IPCNamespace() - ipcns.IncRef() - return ipcns - case CtxTask: - return t - case auth.CtxCredentials: - return t.Credentials() - case context.CtxThreadGroupID: - return int32(t.ThreadGroup().ID()) - case fs.CtxRoot: - return t.fsContext.RootDirectory() - case vfs.CtxRoot: - return t.fsContext.RootDirectoryVFS2() - case vfs.CtxMountNamespace: - t.mountNamespaceVFS2.IncRef() - return t.mountNamespaceVFS2 - case fs.CtxDirentCacheLimiter: - return t.k.DirentCacheLimiter - case inet.CtxStack: - return t.NetworkContext() - case ktime.CtxRealtimeClock: - return t.k.RealtimeClock() - case limits.CtxLimits: - return t.tg.limits - case pgalloc.CtxMemoryFile: - return t.k.mf - case pgalloc.CtxMemoryFileProvider: - return t.k - case platform.CtxPlatform: - return t.k - case uniqueid.CtxGlobalUniqueID: - return t.k.UniqueID() - case uniqueid.CtxGlobalUniqueIDProvider: - return t.k - case uniqueid.CtxInotifyCookie: - return t.k.GenerateInotifyCookie() - case unimpl.CtxEvents: - return t.k - default: - return nil - } -} - // SetClearTID sets t's cleartid. // // Preconditions: The caller must be running on the task goroutine. @@ -751,12 +694,12 @@ func (t *Task) IsChrooted() bool { return root != realRoot } -// TaskContext returns t's TaskContext. +// TaskImage returns t's TaskImage. // // Precondition: The caller must be running on the task goroutine, or t.mu must // be locked. -func (t *Task) TaskContext() *TaskContext { - return &t.tc +func (t *Task) TaskImage() *TaskImage { + return &t.image } // FSContext returns t's FSContext. FSContext does not take an additional diff --git a/pkg/sentry/kernel/task_acct.go b/pkg/sentry/kernel/task_acct.go index 5f3e60fe8..e574997f7 100644 --- a/pkg/sentry/kernel/task_acct.go +++ b/pkg/sentry/kernel/task_acct.go @@ -136,14 +136,14 @@ func (tg *ThreadGroup) IOUsage() *usage.IO { func (t *Task) Name() string { t.mu.Lock() defer t.mu.Unlock() - return t.tc.Name + return t.image.Name } // SetName changes t's name. func (t *Task) SetName(name string) { t.mu.Lock() defer t.mu.Unlock() - t.tc.Name = name + t.image.Name = name t.Debugf("Set thread name to %q", name) } diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go index 4a4a69ee2..9419f2e95 100644 --- a/pkg/sentry/kernel/task_block.go +++ b/pkg/sentry/kernel/task_block.go @@ -20,6 +20,7 @@ import ( "time" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) @@ -32,6 +33,8 @@ import ( // // - An error which is nil if an event is received from C, ETIMEDOUT if the timeout // expired, and syserror.ErrInterrupted if t is interrupted. +// +// Preconditions: The caller must be running on the task goroutine. func (t *Task) BlockWithTimeout(C chan struct{}, haveTimeout bool, timeout time.Duration) (time.Duration, error) { if !haveTimeout { return timeout, t.block(C, nil) @@ -112,7 +115,14 @@ func (t *Task) Block(C <-chan struct{}) error { // block blocks a task on one of many events. // N.B. defer is too expensive to be used here. +// +// Preconditions: The caller must be running on the task goroutine. func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error { + // This function is very hot; skip this check outside of +race builds. + if sync.RaceEnabled { + t.assertTaskGoroutine() + } + // Fast path if the request is already done. select { case <-C: @@ -156,33 +166,39 @@ func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error { } } -// SleepStart implements amutex.Sleeper.SleepStart. +// SleepStart implements context.ChannelSleeper.SleepStart. func (t *Task) SleepStart() <-chan struct{} { + t.assertTaskGoroutine() t.Deactivate() t.accountTaskGoroutineEnter(TaskGoroutineBlockedInterruptible) return t.interruptChan } -// SleepFinish implements amutex.Sleeper.SleepFinish. +// SleepFinish implements context.ChannelSleeper.SleepFinish. func (t *Task) SleepFinish(success bool) { if !success { - // The interrupted notification is consumed only at the top-level - // (Run). Therefore we attempt to reset the pending notification. - // This will also elide our next entry back into the task, so we - // will process signals, state changes, etc. + // Our caller received from t.interruptChan; we need to re-send to it + // to ensure that t.interrupted() is still true. t.interruptSelf() } t.accountTaskGoroutineLeave(TaskGoroutineBlockedInterruptible) t.Activate() } -// Interrupted implements amutex.Sleeper.Interrupted +// Interrupted implements context.ChannelSleeper.Interrupted. func (t *Task) Interrupted() bool { - return len(t.interruptChan) != 0 + if t.interrupted() { + return true + } + // Indicate that t's task goroutine is still responsive (i.e. reset the + // watchdog timer). + t.accountTaskGoroutineRunning() + return false } // UninterruptibleSleepStart implements context.Context.UninterruptibleSleepStart. func (t *Task) UninterruptibleSleepStart(deactivate bool) { + t.assertTaskGoroutine() if deactivate { t.Deactivate() } @@ -198,13 +214,17 @@ func (t *Task) UninterruptibleSleepFinish(activate bool) { } // interrupted returns true if interrupt or interruptSelf has been called at -// least once since the last call to interrupted. +// least once since the last call to unsetInterrupted. func (t *Task) interrupted() bool { + return len(t.interruptChan) != 0 +} + +// unsetInterrupted causes interrupted to return false until the next call to +// interrupt or interruptSelf. +func (t *Task) unsetInterrupted() { select { case <-t.interruptChan: - return true default: - return false } } @@ -220,9 +240,7 @@ func (t *Task) interrupt() { func (t *Task) interruptSelf() { select { case t.interruptChan <- struct{}{}: - t.Debugf("Interrupt queued") default: - t.Debugf("Dropping duplicate interrupt") } // platform.Context.Interrupt() is unnecessary since a task goroutine // calling interruptSelf() cannot also be blocked in diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index 527344162..f305e69c0 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -115,7 +115,7 @@ type CloneOptions struct { ParentTID usermem.Addr // If Vfork is true, place the parent in vforkStop until the cloned task - // releases its TaskContext. + // releases its TaskImage. Vfork bool // If Untraced is true, do not report PTRACE_EVENT_CLONE/FORK/VFORK for @@ -226,20 +226,20 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { }) } - tc, err := t.tc.Fork(t, t.k, !opts.NewAddressSpace) + image, err := t.image.Fork(t, t.k, !opts.NewAddressSpace) if err != nil { return 0, nil, err } cu.Add(func() { - tc.release() + image.release() }) // clone() returns 0 in the child. - tc.Arch.SetReturn(0) + image.Arch.SetReturn(0) if opts.Stack != 0 { - tc.Arch.SetStack(uintptr(opts.Stack)) + image.Arch.SetStack(uintptr(opts.Stack)) } if opts.SetTLS { - if !tc.Arch.SetTLS(uintptr(opts.TLS)) { + if !image.Arch.SetTLS(uintptr(opts.TLS)) { return 0, nil, syserror.EPERM } } @@ -288,7 +288,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { Kernel: t.k, ThreadGroup: tg, SignalMask: t.SignalMask(), - TaskContext: tc, + TaskImage: image, FSContext: fsContext, FDTable: fdTable, Credentials: creds, diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index d1136461a..70b0699dc 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,159 +15,175 @@ package kernel import ( - "fmt" + "time" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/kernel/futex" - "gvisor.dev/gvisor/pkg/sentry/loader" - "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/inet" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sentry/pgalloc" + "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/unimpl" + "gvisor.dev/gvisor/pkg/sentry/uniqueid" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" ) -var errNoSyscalls = syserr.New("no syscall table found", linux.ENOEXEC) - -// Auxmap contains miscellaneous data for the task. -type Auxmap map[string]interface{} - -// TaskContext is the subset of a task's data that is provided by the loader. -// -// +stateify savable -type TaskContext struct { - // Name is the thread name set by the prctl(PR_SET_NAME) system call. - Name string - - // Arch is the architecture-specific context (registers, etc.) - Arch arch.Context - - // MemoryManager is the task's address space. - MemoryManager *mm.MemoryManager +// Deadline implements context.Context.Deadline. +func (t *Task) Deadline() (time.Time, bool) { + return time.Time{}, false +} - // fu implements futexes in the address space. - fu *futex.Manager +// Done implements context.Context.Done. +func (t *Task) Done() <-chan struct{} { + return nil +} - // st is the task's syscall table. - st *SyscallTable `state:".(syscallTableInfo)"` +// Err implements context.Context.Err. +func (t *Task) Err() error { + return nil } -// release releases all resources held by the TaskContext. release is called by -// the task when it execs into a new TaskContext or exits. -func (tc *TaskContext) release() { - // Nil out pointers so that if the task is saved after release, it doesn't - // follow the pointers to possibly now-invalid objects. - if tc.MemoryManager != nil { - tc.MemoryManager.DecUsers(context.Background()) - tc.MemoryManager = nil +// Value implements context.Context.Value. +// +// Preconditions: The caller must be running on the task goroutine. +func (t *Task) Value(key interface{}) interface{} { + // This function is very hot; skip this check outside of +race builds. + if sync.RaceEnabled { + t.assertTaskGoroutine() } - tc.fu = nil + return t.contextValue(key, true /* isTaskGoroutine */) } -// Fork returns a duplicate of tc. The copied TaskContext always has an -// independent arch.Context. If shareAddressSpace is true, the copied -// TaskContext shares an address space with the original; otherwise, the copied -// TaskContext has an independent address space that is initially a duplicate -// of the original's. -func (tc *TaskContext) Fork(ctx context.Context, k *Kernel, shareAddressSpace bool) (*TaskContext, error) { - newTC := &TaskContext{ - Name: tc.Name, - Arch: tc.Arch.Fork(), - st: tc.st, - } - if shareAddressSpace { - newTC.MemoryManager = tc.MemoryManager - if newTC.MemoryManager != nil { - if !newTC.MemoryManager.IncUsers() { - // Shouldn't be possible since tc.MemoryManager should be a - // counted user. - panic(fmt.Sprintf("TaskContext.Fork called with userless TaskContext.MemoryManager")) - } +func (t *Task) contextValue(key interface{}, isTaskGoroutine bool) interface{} { + switch key { + case CtxCanTrace: + return t.CanTrace + case CtxKernel: + return t.k + case CtxPIDNamespace: + return t.tg.pidns + case CtxUTSNamespace: + if !isTaskGoroutine { + t.mu.Lock() + defer t.mu.Unlock() + } + return t.utsns + case CtxIPCNamespace: + if !isTaskGoroutine { + t.mu.Lock() + defer t.mu.Unlock() + } + ipcns := t.ipcns + ipcns.IncRef() + return ipcns + case CtxTask: + return t + case auth.CtxCredentials: + return t.creds.Load() + case context.CtxThreadGroupID: + return int32(t.tg.ID()) + case fs.CtxRoot: + if !isTaskGoroutine { + t.mu.Lock() + defer t.mu.Unlock() + } + return t.fsContext.RootDirectory() + case vfs.CtxRoot: + if !isTaskGoroutine { + t.mu.Lock() + defer t.mu.Unlock() } - newTC.fu = tc.fu - } else { - newMM, err := tc.MemoryManager.Fork(ctx) - if err != nil { - return nil, err + return t.fsContext.RootDirectoryVFS2() + case vfs.CtxMountNamespace: + if !isTaskGoroutine { + t.mu.Lock() + defer t.mu.Unlock() } - newTC.MemoryManager = newMM - newTC.fu = k.futexes.Fork() + t.mountNamespaceVFS2.IncRef() + return t.mountNamespaceVFS2 + case fs.CtxDirentCacheLimiter: + return t.k.DirentCacheLimiter + case inet.CtxStack: + return t.NetworkContext() + case ktime.CtxRealtimeClock: + return t.k.RealtimeClock() + case limits.CtxLimits: + return t.tg.limits + case pgalloc.CtxMemoryFile: + return t.k.mf + case pgalloc.CtxMemoryFileProvider: + return t.k + case platform.CtxPlatform: + return t.k + case uniqueid.CtxGlobalUniqueID: + return t.k.UniqueID() + case uniqueid.CtxGlobalUniqueIDProvider: + return t.k + case uniqueid.CtxInotifyCookie: + return t.k.GenerateInotifyCookie() + case unimpl.CtxEvents: + return t.k + default: + return nil } - return newTC, nil } -// Arch returns t's arch.Context. -// -// Preconditions: The caller must be running on the task goroutine, or t.mu -// must be locked. -func (t *Task) Arch() arch.Context { - return t.tc.Arch +// taskAsyncContext implements context.Context for a goroutine that performs +// work on behalf of a Task, but is not the task goroutine. +type taskAsyncContext struct { + context.NoopSleeper + + t *Task } -// MemoryManager returns t's MemoryManager. MemoryManager does not take an -// additional reference on the returned MM. -// -// Preconditions: The caller must be running on the task goroutine, or t.mu -// must be locked. -func (t *Task) MemoryManager() *mm.MemoryManager { - return t.tc.MemoryManager +// AsyncContext returns a context.Context representing t. The returned +// context.Context is intended for use by goroutines other than t's task +// goroutine; for example, signal delivery to t will not interrupt goroutines +// that are blocking using the returned context.Context. +func (t *Task) AsyncContext() context.Context { + return taskAsyncContext{t: t} } -// SyscallTable returns t's syscall table. -// -// Preconditions: The caller must be running on the task goroutine, or t.mu -// must be locked. -func (t *Task) SyscallTable() *SyscallTable { - return t.tc.st +// Debugf implements log.Logger.Debugf. +func (ctx taskAsyncContext) Debugf(format string, v ...interface{}) { + ctx.t.Debugf(format, v...) } -// Stack returns the userspace stack. -// -// Preconditions: The caller must be running on the task goroutine, or t.mu -// must be locked. -func (t *Task) Stack() *arch.Stack { - return &arch.Stack{ - Arch: t.Arch(), - IO: t.MemoryManager(), - Bottom: usermem.Addr(t.Arch().Stack()), - } +// Infof implements log.Logger.Infof. +func (ctx taskAsyncContext) Infof(format string, v ...interface{}) { + ctx.t.Infof(format, v...) } -// LoadTaskImage loads a specified file into a new TaskContext. -// -// args.MemoryManager does not need to be set by the caller. -func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*TaskContext, *syserr.Error) { - // If File is not nil, we should load that instead of resolving Filename. - if args.File != nil { - args.Filename = args.File.PathnameWithDeleted(ctx) - } +// Warningf implements log.Logger.Warningf. +func (ctx taskAsyncContext) Warningf(format string, v ...interface{}) { + ctx.t.Warningf(format, v...) +} + +// IsLogging implements log.Logger.IsLogging. +func (ctx taskAsyncContext) IsLogging(level log.Level) bool { + return ctx.t.IsLogging(level) +} - // Prepare a new user address space to load into. - m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation) - defer m.DecUsers(ctx) - args.MemoryManager = m +// Deadline implements context.Context.Deadline. +func (ctx taskAsyncContext) Deadline() (time.Time, bool) { + return time.Time{}, false +} - os, ac, name, err := loader.Load(ctx, args, k.extraAuxv, k.vdso) - if err != nil { - return nil, err - } +// Done implements context.Context.Done. +func (ctx taskAsyncContext) Done() <-chan struct{} { + return nil +} - // Lookup our new syscall table. - st, ok := LookupSyscallTable(os, ac.Arch()) - if !ok { - // No syscall table found. This means that the ELF binary does not match - // the architecture. - return nil, errNoSyscalls - } +// Err implements context.Context.Err. +func (ctx taskAsyncContext) Err() error { + return nil +} - if !m.IncUsers() { - panic("Failed to increment users count on new MM") - } - return &TaskContext{ - Name: name, - Arch: ac, - MemoryManager: m, - fu: k.futexes.Fork(), - st: st, - }, nil +// Value implements context.Context.Value. +func (ctx taskAsyncContext) Value(key interface{}) interface{} { + return ctx.t.contextValue(key, false /* isTaskGoroutine */) } diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 412d471d3..d9897e802 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -83,11 +83,12 @@ type execStop struct{} func (*execStop) Killable() bool { return true } // Execve implements the execve(2) syscall by killing all other tasks in its -// thread group and switching to newTC. Execve always takes ownership of newTC. +// thread group and switching to newImage. Execve always takes ownership of +// newImage. // // Preconditions: The caller must be running Task.doSyscallInvoke on the task // goroutine. -func (t *Task) Execve(newTC *TaskContext) (*SyscallControl, error) { +func (t *Task) Execve(newImage *TaskImage) (*SyscallControl, error) { t.tg.pidns.owner.mu.Lock() defer t.tg.pidns.owner.mu.Unlock() t.tg.signalHandlers.mu.Lock() @@ -96,7 +97,7 @@ func (t *Task) Execve(newTC *TaskContext) (*SyscallControl, error) { if t.tg.exiting || t.tg.execing != nil { // We lost to a racing group-exit, kill, or exec from another thread // and should just exit. - newTC.release() + newImage.release() return nil, syserror.EINTR } @@ -118,7 +119,7 @@ func (t *Task) Execve(newTC *TaskContext) (*SyscallControl, error) { t.beginInternalStopLocked((*execStop)(nil)) } - return &SyscallControl{next: &runSyscallAfterExecStop{newTC}, ignoreReturn: true}, nil + return &SyscallControl{next: &runSyscallAfterExecStop{newImage}, ignoreReturn: true}, nil } // The runSyscallAfterExecStop state continues execve(2) after all siblings of @@ -126,16 +127,16 @@ func (t *Task) Execve(newTC *TaskContext) (*SyscallControl, error) { // // +stateify savable type runSyscallAfterExecStop struct { - tc *TaskContext + image *TaskImage } func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { - t.traceExecEvent(r.tc) + t.traceExecEvent(r.image) t.tg.pidns.owner.mu.Lock() t.tg.execing = nil if t.killed() { t.tg.pidns.owner.mu.Unlock() - r.tc.release() + r.image.release() return (*runInterrupt)(nil) } // We are the thread group leader now. Save our old thread ID for @@ -214,7 +215,7 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { // executables (set-user/group-ID bits and file capabilities). This // allows us to unconditionally enable user dumpability on the new mm. // See fs/exec.c:setup_new_exec. - r.tc.MemoryManager.SetDumpability(mm.UserDumpable) + r.image.MemoryManager.SetDumpability(mm.UserDumpable) // Switch to the new process. t.MemoryManager().Deactivate() @@ -222,8 +223,8 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { // Update credentials to reflect the execve. This should precede switching // MMs to ensure that dumpability has been reset first, if needed. t.updateCredsForExecLocked() - t.tc.release() - t.tc = *r.tc + t.image.release() + t.image = *r.image t.mu.Unlock() t.unstopVforkParent() t.p.FullStateChanged() diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index ce7b9641d..c5137c282 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -266,7 +266,7 @@ func (*runExitMain) execute(t *Task) taskRunState { t.updateRSSLocked() t.tg.pidns.owner.mu.Unlock() t.mu.Lock() - t.tc.release() + t.image.release() t.mu.Unlock() // Releasing the MM unblocks a blocked CLONE_VFORK parent. diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go index c80391475..195c7da9b 100644 --- a/pkg/sentry/kernel/task_futex.go +++ b/pkg/sentry/kernel/task_futex.go @@ -26,7 +26,7 @@ import ( // Preconditions: The caller must be running on the task goroutine, or t.mu // must be locked. func (t *Task) Futex() *futex.Manager { - return t.tc.fu + return t.image.fu } // SwapUint32 implements futex.Target.SwapUint32. diff --git a/pkg/sentry/kernel/task_image.go b/pkg/sentry/kernel/task_image.go new file mode 100644 index 000000000..ce5fbd299 --- /dev/null +++ b/pkg/sentry/kernel/task_image.go @@ -0,0 +1,173 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel/futex" + "gvisor.dev/gvisor/pkg/sentry/loader" + "gvisor.dev/gvisor/pkg/sentry/mm" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/usermem" +) + +var errNoSyscalls = syserr.New("no syscall table found", linux.ENOEXEC) + +// Auxmap contains miscellaneous data for the task. +type Auxmap map[string]interface{} + +// TaskImage is the subset of a task's data that is provided by the loader. +// +// +stateify savable +type TaskImage struct { + // Name is the thread name set by the prctl(PR_SET_NAME) system call. + Name string + + // Arch is the architecture-specific context (registers, etc.) + Arch arch.Context + + // MemoryManager is the task's address space. + MemoryManager *mm.MemoryManager + + // fu implements futexes in the address space. + fu *futex.Manager + + // st is the task's syscall table. + st *SyscallTable `state:".(syscallTableInfo)"` +} + +// release releases all resources held by the TaskImage. release is called by +// the task when it execs into a new TaskImage or exits. +func (image *TaskImage) release() { + // Nil out pointers so that if the task is saved after release, it doesn't + // follow the pointers to possibly now-invalid objects. + if image.MemoryManager != nil { + image.MemoryManager.DecUsers(context.Background()) + image.MemoryManager = nil + } + image.fu = nil +} + +// Fork returns a duplicate of image. The copied TaskImage always has an +// independent arch.Context. If shareAddressSpace is true, the copied +// TaskImage shares an address space with the original; otherwise, the copied +// TaskImage has an independent address space that is initially a duplicate +// of the original's. +func (image *TaskImage) Fork(ctx context.Context, k *Kernel, shareAddressSpace bool) (*TaskImage, error) { + newImage := &TaskImage{ + Name: image.Name, + Arch: image.Arch.Fork(), + st: image.st, + } + if shareAddressSpace { + newImage.MemoryManager = image.MemoryManager + if newImage.MemoryManager != nil { + if !newImage.MemoryManager.IncUsers() { + // Shouldn't be possible since image.MemoryManager should be a + // counted user. + panic(fmt.Sprintf("TaskImage.Fork called with userless TaskImage.MemoryManager")) + } + } + newImage.fu = image.fu + } else { + newMM, err := image.MemoryManager.Fork(ctx) + if err != nil { + return nil, err + } + newImage.MemoryManager = newMM + newImage.fu = k.futexes.Fork() + } + return newImage, nil +} + +// Arch returns t's arch.Context. +// +// Preconditions: The caller must be running on the task goroutine, or t.mu +// must be locked. +func (t *Task) Arch() arch.Context { + return t.image.Arch +} + +// MemoryManager returns t's MemoryManager. MemoryManager does not take an +// additional reference on the returned MM. +// +// Preconditions: The caller must be running on the task goroutine, or t.mu +// must be locked. +func (t *Task) MemoryManager() *mm.MemoryManager { + return t.image.MemoryManager +} + +// SyscallTable returns t's syscall table. +// +// Preconditions: The caller must be running on the task goroutine, or t.mu +// must be locked. +func (t *Task) SyscallTable() *SyscallTable { + return t.image.st +} + +// Stack returns the userspace stack. +// +// Preconditions: The caller must be running on the task goroutine, or t.mu +// must be locked. +func (t *Task) Stack() *arch.Stack { + return &arch.Stack{ + Arch: t.Arch(), + IO: t.MemoryManager(), + Bottom: usermem.Addr(t.Arch().Stack()), + } +} + +// LoadTaskImage loads a specified file into a new TaskImage. +// +// args.MemoryManager does not need to be set by the caller. +func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*TaskImage, *syserr.Error) { + // If File is not nil, we should load that instead of resolving Filename. + if args.File != nil { + args.Filename = args.File.PathnameWithDeleted(ctx) + } + + // Prepare a new user address space to load into. + m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation) + defer m.DecUsers(ctx) + args.MemoryManager = m + + os, ac, name, err := loader.Load(ctx, args, k.extraAuxv, k.vdso) + if err != nil { + return nil, err + } + + // Lookup our new syscall table. + st, ok := LookupSyscallTable(os, ac.Arch()) + if !ok { + // No syscall table found. This means that the ELF binary does not match + // the architecture. + return nil, errNoSyscalls + } + + if !m.IncUsers() { + panic("Failed to increment users count on new MM") + } + return &TaskImage{ + Name: name, + Arch: ac, + MemoryManager: m, + fu: k.futexes.Fork(), + st: st, + }, nil +} diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go index d23cea802..c70e5e6ce 100644 --- a/pkg/sentry/kernel/task_log.go +++ b/pkg/sentry/kernel/task_log.go @@ -19,6 +19,7 @@ import ( "runtime/trace" "sort" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/usermem" ) @@ -215,7 +216,7 @@ func (t *Task) rebuildTraceContext(tid ThreadID) { // arbitrarily large (in general it won't be, especially for cases // where we're collecting a brief profile), so using the TID is a // reasonable compromise in this case. - t.traceContext, t.traceTask = trace.NewTask(t, fmt.Sprintf("tid:%d", tid)) + t.traceContext, t.traceTask = trace.NewTask(context.Background(), fmt.Sprintf("tid:%d", tid)) } // traceCloneEvent is called when a new task is spawned. @@ -237,11 +238,11 @@ func (t *Task) traceExitEvent() { } // traceExecEvent is called when a task calls exec. -func (t *Task) traceExecEvent(tc *TaskContext) { +func (t *Task) traceExecEvent(image *TaskImage) { if !trace.IsEnabled() { return } - file := tc.MemoryManager.Executable() + file := image.MemoryManager.Executable() if file == nil { trace.Logf(t.traceContext, traceCategory, "exec: << unknown >>") return diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index 8dc3fec90..3ccecf4b6 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -16,11 +16,13 @@ package kernel import ( "bytes" + "fmt" "runtime" "runtime/trace" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/goid" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/hostcpu" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -57,6 +59,8 @@ type taskRunState interface { // make it visible in stack dumps. A goroutine for a given task can be identified // searching for Task.run()'s argument value. func (t *Task) run(threadID uintptr) { + atomic.StoreInt64(&t.goid, goid.Get()) + // Construct t.blockingTimer here. We do this here because we can't // reconstruct t.blockingTimer during restore in Task.afterLoad(), because // kernel.timekeeper.SetClocks() hasn't been called yet. @@ -99,6 +103,9 @@ func (t *Task) run(threadID uintptr) { t.tg.pidns.owner.runningGoroutines.Done() t.p.Release() + // Deferring this store triggers a false positive in the race + // detector (https://github.com/golang/go/issues/42599). + atomic.StoreInt64(&t.goid, 0) // Keep argument alive because stack trace for dead variables may not be correct. runtime.KeepAlive(threadID) return @@ -317,7 +324,7 @@ func (app *runApp) execute(t *Task) taskRunState { // region. We should be able to easily identify // vsyscalls by having a <fault><syscall> pair. if at.Execute { - if sysno, ok := t.tc.st.LookupEmulate(addr); ok { + if sysno, ok := t.image.st.LookupEmulate(addr); ok { return t.doVsyscall(addr, sysno) } } @@ -375,6 +382,19 @@ func (app *runApp) execute(t *Task) taskRunState { } } +// assertTaskGoroutine panics if the caller is not running on t's task +// goroutine. +func (t *Task) assertTaskGoroutine() { + if got, want := goid.Get(), atomic.LoadInt64(&t.goid); got != want { + panic(fmt.Sprintf("running on goroutine %d (task goroutine for kernel.Task %p is %d)", got, t, want)) + } +} + +// GoroutineID returns the ID of t's task goroutine. +func (t *Task) GoroutineID() int64 { + return atomic.LoadInt64(&t.goid) +} + // waitGoroutineStoppedOrExited blocks until t's task goroutine stops or exits. func (t *Task) waitGoroutineStoppedOrExited() { t.goroutineStopped.Wait() diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go index 52c55d13d..9ba5f8d78 100644 --- a/pkg/sentry/kernel/task_sched.go +++ b/pkg/sentry/kernel/task_sched.go @@ -157,6 +157,18 @@ func (t *Task) accountTaskGoroutineLeave(state TaskGoroutineState) { t.goschedSeq.EndWrite() } +// Preconditions: The caller must be running on the task goroutine. +func (t *Task) accountTaskGoroutineRunning() { + now := t.k.CPUClockNow() + if t.gosched.State != TaskGoroutineRunningSys { + panic(fmt.Sprintf("Task goroutine in state %v (expected %v)", t.gosched.State, TaskGoroutineRunningSys)) + } + t.goschedSeq.BeginWrite() + t.gosched.SysTicks += now - t.gosched.Timestamp + t.gosched.Timestamp = now + t.goschedSeq.EndWrite() +} + // TaskGoroutineSchedInfo returns a copy of t's task goroutine scheduling info. // Most clients should use t.CPUStats() instead. func (t *Task) TaskGoroutineSchedInfo() TaskGoroutineSchedInfo { diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index ebdb83061..42dd3e278 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -619,9 +619,6 @@ func (t *Task) setSignalMaskLocked(mask linux.SignalSet) { return } }) - // We have to re-issue the interrupt consumed by t.interrupted() since - // it might have been for a different reason. - t.interruptSelf() } // Conversely, if the new mask unblocks any signals that were blocked by @@ -931,10 +928,10 @@ func (t *Task) signalStop(target *Task, code int32, status int32) { type runInterrupt struct{} func (*runInterrupt) execute(t *Task) taskRunState { - // Interrupts are de-duplicated (if t is interrupted twice before - // t.interrupted() is called, t.interrupted() will only return true once), - // so early exits from this function must re-enter the runInterrupt state - // to check for more interrupt-signaled conditions. + // Interrupts are de-duplicated (t.unsetInterrupted() will undo the effect + // of all previous calls to t.interrupted() regardless of how many such + // calls there have been), so early exits from this function must re-enter + // the runInterrupt state to check for more interrupt-signaled conditions. t.tg.signalHandlers.mu.Lock() @@ -1080,6 +1077,7 @@ func (*runInterrupt) execute(t *Task) taskRunState { return t.deliverSignal(info, act) } + t.unsetInterrupted() t.tg.signalHandlers.mu.Unlock() return (*runApp)(nil) } diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index 8e28230cc..36e1384f1 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -46,10 +46,10 @@ type TaskConfig struct { // SignalMask is the new task's initial signal mask. SignalMask linux.SignalSet - // TaskContext is the TaskContext of the new task. Ownership of the - // TaskContext is transferred to TaskSet.NewTask, whether or not it + // TaskImage is the TaskImage of the new task. Ownership of the + // TaskImage is transferred to TaskSet.NewTask, whether or not it // succeeds. - TaskContext *TaskContext + TaskImage *TaskImage // FSContext is the FSContext of the new task. A reference must be held on // FSContext, which is transferred to TaskSet.NewTask whether or not it @@ -105,7 +105,7 @@ type TaskConfig struct { func (ts *TaskSet) NewTask(ctx context.Context, cfg *TaskConfig) (*Task, error) { t, err := ts.newTask(cfg) if err != nil { - cfg.TaskContext.release() + cfg.TaskImage.release() cfg.FSContext.DecRef(ctx) cfg.FDTable.DecRef(ctx) cfg.IPCNamespace.DecRef(ctx) @@ -121,7 +121,7 @@ func (ts *TaskSet) NewTask(ctx context.Context, cfg *TaskConfig) (*Task, error) // of cfg if it succeeds. func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { tg := cfg.ThreadGroup - tc := cfg.TaskContext + image := cfg.TaskImage t := &Task{ taskNode: taskNode{ tg: tg, @@ -132,7 +132,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { interruptChan: make(chan struct{}, 1), signalMask: cfg.SignalMask, signalStack: arch.SignalStack{Flags: arch.SignalStackFlagDisable}, - tc: *tc, + image: *image, fsContext: cfg.FSContext, fdTable: cfg.FDTable, p: cfg.Kernel.Platform.NewContext(), diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index fd92c3873..3f5be276b 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -263,13 +263,6 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return usermem.NoAccess, platform.ErrContextInterrupt case ring0.El0SyncUndef: return c.fault(int32(syscall.SIGILL), info) - case ring0.El1SyncUndef: - *info = arch.SignalInfo{ - Signo: int32(syscall.SIGILL), - Code: 1, // ILL_ILLOPC (illegal opcode). - } - info.SetAddr(switchOpts.Registers.Pc) // Include address. - return usermem.AccessType{}, platform.ErrContextSignal default: panic(fmt.Sprintf("unexpected vector: 0x%x", vector)) } diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go index 327d48465..c51df2811 100644 --- a/pkg/sentry/platform/ring0/aarch64.go +++ b/pkg/sentry/platform/ring0/aarch64.go @@ -90,6 +90,7 @@ const ( El0SyncIa El0SyncFpsimdAcc El0SyncSveAcc + El0SyncFpsimdExc El0SyncSys El0SyncSpPc El0SyncUndef diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index f489ad352..5f4b2105a 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -324,6 +324,18 @@ MOVD CPU_TTBR0_KVM(from), RSV_REG; \ MSR RSV_REG, TTBR0_EL1; +TEXT ·EnableVFP(SB),NOSPLIT,$0 + MOVD $FPEN_ENABLE, R0 + WORD $0xd5181040 //MSR R0, CPACR_EL1 + ISB $15 + RET + +TEXT ·DisableVFP(SB),NOSPLIT,$0 + MOVD $0, R0 + WORD $0xd5181040 //MSR R0, CPACR_EL1 + ISB $15 + RET + #define VFP_ENABLE \ MOVD $FPEN_ENABLE, R0; \ WORD $0xd5181040; \ //MSR R0, CPACR_EL1 @@ -370,12 +382,12 @@ MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \ LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack. -// EXCEPTION_WITH_ERROR is a common exception handler function. -#define EXCEPTION_WITH_ERROR(user, vector) \ +// EXCEPTION_EL0 is a common el0 exception handler function. +#define EXCEPTION_EL0(vector) \ WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 WORD $0xd538601a; \ //MRS FAR_EL1, R26 MOVD R26, CPU_FAULT_ADDR(RSV_REG); \ - MOVD $user, R3; \ + MOVD $1, R3; \ MOVD R3, CPU_ERROR_TYPE(RSV_REG); \ // Set error type to user. MOVD $vector, R3; \ MOVD R3, CPU_VECTOR_CODE(RSV_REG); \ @@ -383,6 +395,12 @@ MOVD R3, CPU_ERROR_CODE(RSV_REG); \ B ·kernelExitToEl1(SB); +// EXCEPTION_EL1 is a common el1 exception handler function. +#define EXCEPTION_EL1(vector) \ + MOVD $vector, R3; \ + MOVD R3, 8(RSP); \ + B ·HaltEl1ExceptionAndResume(SB); + // storeAppASID writes the application's asid value. TEXT ·storeAppASID(SB),NOSPLIT,$0-8 MOVD asid+0(FP), R1 @@ -430,6 +448,16 @@ TEXT ·HaltEl1SvcAndResume(SB),NOSPLIT,$0 CALL ·kernelSyscall(SB) // Call the trampoline. B ·kernelExitToEl1(SB) // Resume. +// HaltEl1ExceptionAndResume calls Hooks.KernelException and resume. +TEXT ·HaltEl1ExceptionAndResume(SB),NOSPLIT,$0-8 + WORD $0xd538d092 // MRS TPIDR_EL1, R18 + MOVD CPU_SELF(RSV_REG), R3 // Load vCPU. + MOVD R3, 8(RSP) // First argument (vCPU). + MOVD vector+0(FP), R3 + MOVD R3, 16(RSP) // Second argument (vector). + CALL ·kernelException(SB) // Call the trampoline. + B ·kernelExitToEl1(SB) // Resume. + // Shutdown stops the guest. TEXT ·Shutdown(SB),NOSPLIT,$0 // PSCI EVENT. @@ -592,39 +620,22 @@ TEXT ·El1_sync(SB),NOSPLIT,$0 B el1_invalid el1_da: + EXCEPTION_EL1(El1SyncDa) el1_ia: - WORD $0xd538d092 //MRS TPIDR_EL1, R18 - WORD $0xd538601a //MRS FAR_EL1, R26 - - MOVD R26, CPU_FAULT_ADDR(RSV_REG) - - MOVD $0, CPU_ERROR_TYPE(RSV_REG) - - MOVD $PageFault, R3 - MOVD R3, CPU_VECTOR_CODE(RSV_REG) - - B ·HaltAndResume(SB) - + EXCEPTION_EL1(El1SyncIa) el1_sp_pc: - B ·Shutdown(SB) - + EXCEPTION_EL1(El1SyncSpPc) el1_undef: - B ·Shutdown(SB) - + EXCEPTION_EL1(El1SyncUndef) el1_svc: - MOVD $0, CPU_ERROR_CODE(RSV_REG) - MOVD $0, CPU_ERROR_TYPE(RSV_REG) B ·HaltEl1SvcAndResume(SB) - el1_dbg: - B ·Shutdown(SB) - + EXCEPTION_EL1(El1SyncDbg) el1_fpsimd_acc: VFP_ENABLE B ·kernelExitToEl1(SB) // Resume. - el1_invalid: - B ·Shutdown(SB) + EXCEPTION_EL1(El1SyncInv) // El1_irq is the handler for El1_irq. TEXT ·El1_irq(SB),NOSPLIT,$0 @@ -680,28 +691,21 @@ el0_svc: el0_da: el0_ia: - EXCEPTION_WITH_ERROR(1, PageFault) - + EXCEPTION_EL0(PageFault) el0_fpsimd_acc: - B ·Shutdown(SB) - + EXCEPTION_EL0(El0SyncFpsimdAcc) el0_sve_acc: - B ·Shutdown(SB) - + EXCEPTION_EL0(El0SyncSveAcc) el0_fpsimd_exc: - B ·Shutdown(SB) - + EXCEPTION_EL0(El0SyncFpsimdExc) el0_sp_pc: - B ·Shutdown(SB) - + EXCEPTION_EL0(El0SyncSpPc) el0_undef: - EXCEPTION_WITH_ERROR(1, El0SyncUndef) - + EXCEPTION_EL0(El0SyncUndef) el0_dbg: - B ·Shutdown(SB) - + EXCEPTION_EL0(El0SyncDbg) el0_invalid: - B ·Shutdown(SB) + EXCEPTION_EL0(El0SyncInv) TEXT ·El0_irq(SB),NOSPLIT,$0 B ·Shutdown(SB) diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index 6cbbf001f..90a7b8392 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -24,6 +24,10 @@ func HaltAndResume() //go:nosplit func HaltEl1SvcAndResume() +// HaltEl1ExceptionAndResume calls Hooks.KernelException and resume. +//go:nosplit +func HaltEl1ExceptionAndResume() + // init initializes architecture-specific state. func (k *Kernel) init(maxCPUs int) { } @@ -61,11 +65,13 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { regs.Pstate &= ^uint64(PsrFlagsClear) regs.Pstate |= UserFlagsSet + EnableVFP() LoadFloatingPoint(switchOpts.FloatingPointState) kernelExitToEl0() SaveFloatingPoint(switchOpts.FloatingPointState) + DisableVFP() vector = c.vecCode diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index d91a09de1..0dffd33a3 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -53,6 +53,12 @@ func LoadFloatingPoint(*byte) // SaveFloatingPoint saves floating point state. func SaveFloatingPoint(*byte) +// EnableVFP enables fpsimd. +func EnableVFP() + +// DisableVFP disables fpsimd. +func DisableVFP() + // Init sets function pointers based on architectural features. // // This must be called prior to using ring0. diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s index 19c1fca8b..6f4923539 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.s +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -35,62 +35,47 @@ TEXT ·CPACREL1(SB),NOSPLIT,$0-8 RET TEXT ·GetFPCR(SB),NOSPLIT,$0-8 - WORD $0xd53b4201 // MRS NZCV, R1 + MOVD FPCR, R1 MOVD R1, ret+0(FP) RET TEXT ·GetFPSR(SB),NOSPLIT,$0-8 - WORD $0xd53b4421 // MRS FPSR, R1 + MOVD FPSR, R1 MOVD R1, ret+0(FP) RET TEXT ·SetFPCR(SB),NOSPLIT,$0-8 MOVD addr+0(FP), R1 - WORD $0xd51b4201 // MSR R1, NZCV + MOVD R1, FPCR RET TEXT ·SetFPSR(SB),NOSPLIT,$0-8 MOVD addr+0(FP), R1 - WORD $0xd51b4421 // MSR R1, FPSR + MOVD R1, FPSR RET TEXT ·SaveVRegs(SB),NOSPLIT,$0-8 MOVD addr+0(FP), R0 // Skip aarch64_ctx, fpsr, fpcr. - FMOVD F0, 16*1(R0) - FMOVD F1, 16*2(R0) - FMOVD F2, 16*3(R0) - FMOVD F3, 16*4(R0) - FMOVD F4, 16*5(R0) - FMOVD F5, 16*6(R0) - FMOVD F6, 16*7(R0) - FMOVD F7, 16*8(R0) - FMOVD F8, 16*9(R0) - FMOVD F9, 16*10(R0) - FMOVD F10, 16*11(R0) - FMOVD F11, 16*12(R0) - FMOVD F12, 16*13(R0) - FMOVD F13, 16*14(R0) - FMOVD F14, 16*15(R0) - FMOVD F15, 16*16(R0) - FMOVD F16, 16*17(R0) - FMOVD F17, 16*18(R0) - FMOVD F18, 16*19(R0) - FMOVD F19, 16*20(R0) - FMOVD F20, 16*21(R0) - FMOVD F21, 16*22(R0) - FMOVD F22, 16*23(R0) - FMOVD F23, 16*24(R0) - FMOVD F24, 16*25(R0) - FMOVD F25, 16*26(R0) - FMOVD F26, 16*27(R0) - FMOVD F27, 16*28(R0) - FMOVD F28, 16*29(R0) - FMOVD F29, 16*30(R0) - FMOVD F30, 16*31(R0) - FMOVD F31, 16*32(R0) - ISB $15 + ADD $16, R0, R0 + + WORD $0xad000400 // stp q0, q1, [x0] + WORD $0xad010c02 // stp q2, q3, [x0, #32] + WORD $0xad021404 // stp q4, q5, [x0, #64] + WORD $0xad031c06 // stp q6, q7, [x0, #96] + WORD $0xad042408 // stp q8, q9, [x0, #128] + WORD $0xad052c0a // stp q10, q11, [x0, #160] + WORD $0xad06340c // stp q12, q13, [x0, #192] + WORD $0xad073c0e // stp q14, q15, [x0, #224] + WORD $0xad084410 // stp q16, q17, [x0, #256] + WORD $0xad094c12 // stp q18, q19, [x0, #288] + WORD $0xad0a5414 // stp q20, q21, [x0, #320] + WORD $0xad0b5c16 // stp q22, q23, [x0, #352] + WORD $0xad0c6418 // stp q24, q25, [x0, #384] + WORD $0xad0d6c1a // stp q26, q27, [x0, #416] + WORD $0xad0e741c // stp q28, q29, [x0, #448] + WORD $0xad0f7c1e // stp q30, q31, [x0, #480] RET @@ -98,39 +83,24 @@ TEXT ·LoadVRegs(SB),NOSPLIT,$0-8 MOVD addr+0(FP), R0 // Skip aarch64_ctx, fpsr, fpcr. - FMOVD 16*1(R0), F0 - FMOVD 16*2(R0), F1 - FMOVD 16*3(R0), F2 - FMOVD 16*4(R0), F3 - FMOVD 16*5(R0), F4 - FMOVD 16*6(R0), F5 - FMOVD 16*7(R0), F6 - FMOVD 16*8(R0), F7 - FMOVD 16*9(R0), F8 - FMOVD 16*10(R0), F9 - FMOVD 16*11(R0), F10 - FMOVD 16*12(R0), F11 - FMOVD 16*13(R0), F12 - FMOVD 16*14(R0), F13 - FMOVD 16*15(R0), F14 - FMOVD 16*16(R0), F15 - FMOVD 16*17(R0), F16 - FMOVD 16*18(R0), F17 - FMOVD 16*19(R0), F18 - FMOVD 16*20(R0), F19 - FMOVD 16*21(R0), F20 - FMOVD 16*22(R0), F21 - FMOVD 16*23(R0), F22 - FMOVD 16*24(R0), F23 - FMOVD 16*25(R0), F24 - FMOVD 16*26(R0), F25 - FMOVD 16*27(R0), F26 - FMOVD 16*28(R0), F27 - FMOVD 16*29(R0), F28 - FMOVD 16*30(R0), F29 - FMOVD 16*31(R0), F30 - FMOVD 16*32(R0), F31 - ISB $15 + ADD $16, R0, R0 + + WORD $0xad400400 // ldp q0, q1, [x0] + WORD $0xad410c02 // ldp q2, q3, [x0, #32] + WORD $0xad421404 // ldp q4, q5, [x0, #64] + WORD $0xad431c06 // ldp q6, q7, [x0, #96] + WORD $0xad442408 // ldp q8, q9, [x0, #128] + WORD $0xad452c0a // ldp q10, q11, [x0, #160] + WORD $0xad46340c // ldp q12, q13, [x0, #192] + WORD $0xad473c0e // ldp q14, q15, [x0, #224] + WORD $0xad484410 // ldp q16, q17, [x0, #256] + WORD $0xad494c12 // ldp q18, q19, [x0, #288] + WORD $0xad4a5414 // ldp q20, q21, [x0, #320] + WORD $0xad4b5c16 // ldp q22, q23, [x0, #352] + WORD $0xad4c6418 // ldp q24, q25, [x0, #384] + WORD $0xad4d6c1a // ldp q26, q27, [x0, #416] + WORD $0xad4e741c // ldp q28, q29, [x0, #448] + WORD $0xad4f7c1e // ldp q30, q31, [x0, #480] RET diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go index 53bc3353c..b5652deb9 100644 --- a/pkg/sentry/platform/ring0/offsets_arm64.go +++ b/pkg/sentry/platform/ring0/offsets_arm64.go @@ -70,6 +70,7 @@ func Emit(w io.Writer) { fmt.Fprintf(w, "#define El0SyncIa 0x%02x\n", El0SyncIa) fmt.Fprintf(w, "#define El0SyncFpsimdAcc 0x%02x\n", El0SyncFpsimdAcc) fmt.Fprintf(w, "#define El0SyncSveAcc 0x%02x\n", El0SyncSveAcc) + fmt.Fprintf(w, "#define El0SyncFpsimdExc 0x%02x\n", El0SyncFpsimdExc) fmt.Fprintf(w, "#define El0SyncSys 0x%02x\n", El0SyncSys) fmt.Fprintf(w, "#define El0SyncSpPc 0x%02x\n", El0SyncSpPc) fmt.Fprintf(w, "#define El0SyncUndef 0x%02x\n", El0SyncUndef) diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 3baad098b..057f4d294 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -120,9 +120,6 @@ type socketOpsCommon struct { // fixed buffer but only consume this many bytes. sendBufferSize uint32 - // passcred indicates if this socket wants SCM credentials. - passcred bool - // filter indicates that this socket has a BPF filter "installed". // // TODO(gvisor.dev/issue/1119): We don't actually support filtering, @@ -201,10 +198,7 @@ func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { // Passcred implements transport.Credentialer.Passcred. func (s *socketOpsCommon) Passcred() bool { - s.mu.Lock() - passcred := s.passcred - s.mu.Unlock() - return passcred + return s.ep.SocketOptions().GetPassCred() } // ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. @@ -419,9 +413,7 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] } passcred := usermem.ByteOrder.Uint32(opt) - s.mu.Lock() - s.passcred = passcred != 0 - s.mu.Unlock() + s.ep.SocketOptions().SetPassCred(passcred != 0) return nil case linux.SO_ATTACH_FILTER: diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 86c634715..9c927efa0 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -260,8 +260,13 @@ type commonEndpoint interface { // transport.Endpoint.GetSockOpt. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) - // LastError implements tcpip.Endpoint.LastError. + // LastError implements tcpip.Endpoint.LastError and + // transport.Endpoint.LastError. LastError() *tcpip.Error + + // SocketOptions implements tcpip.Endpoint.SocketOptions and + // transport.Endpoint.SocketOptions. + SocketOptions() *tcpip.SocketOptions } // LINT.IfChange @@ -1065,13 +1070,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.PasscredOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetPassCred())) + return &v, nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -1163,13 +1163,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.BroadcastOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetBroadcast())) + return &v, nil case linux.SO_KEEPALIVE: if outLen < sizeOfInt32 { @@ -1916,7 +1911,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.BroadcastOption, v != 0)) + ep.SocketOptions().SetBroadcast(v != 0) + return nil case linux.SO_PASSCRED: if len(optVal) < sizeOfInt32 { @@ -1924,7 +1920,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0)) + ep.SocketOptions().SetPassCred(v != 0) + return nil case linux.SO_KEEPALIVE: if len(optVal) < sizeOfInt32 { @@ -2687,7 +2684,7 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ // Always do at least one fetchReadView, even if the number of bytes to // read is 0. err = s.fetchReadView() - if err != nil { + if err != nil || len(s.readView) == 0 { break } if dst.NumBytes() == 0 { @@ -2710,15 +2707,20 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ } copied += n s.readView.TrimFront(n) - if len(s.readView) == 0 { - atomic.StoreUint32(&s.readViewHasData, 0) - } dst = dst.DropFirst(n) if e != nil { err = syserr.FromError(e) break } + // If we are done reading requested data then stop. + if dst.NumBytes() == 0 { + break + } + } + + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) } // If we managed to copy something, we must deliver it. diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index b648273a4..0324dcd93 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -16,8 +16,6 @@ package transport import ( - "sync/atomic" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" @@ -203,8 +201,12 @@ type Endpoint interface { // procfs. State() uint32 - // LastError implements tcpip.Endpoint.LastError. + // LastError clears and returns the last error reported by the endpoint. LastError() *tcpip.Error + + // SocketOptions returns the structure which contains all the socket + // level options. + SocketOptions() *tcpip.SocketOptions } // A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket @@ -737,10 +739,6 @@ func (e *connectedEndpoint) CloseUnread() { type baseEndpoint struct { *waiter.Queue - // passcred specifies whether SCM_CREDENTIALS socket control messages are - // enabled on this endpoint. Must be accessed atomically. - passcred int32 - // Mutex protects the below fields. sync.Mutex `state:"nosave"` @@ -757,6 +755,8 @@ type baseEndpoint struct { // linger is used for SO_LINGER socket option. linger tcpip.LingerOption + + ops tcpip.SocketOptions } // EventRegister implements waiter.Waitable.EventRegister. @@ -781,7 +781,7 @@ func (e *baseEndpoint) EventUnregister(we *waiter.Entry) { // Passcred implements Credentialer.Passcred. func (e *baseEndpoint) Passcred() bool { - return atomic.LoadInt32(&e.passcred) != 0 + return e.SocketOptions().GetPassCred() } // ConnectedPasscred implements Credentialer.ConnectedPasscred. @@ -791,14 +791,6 @@ func (e *baseEndpoint) ConnectedPasscred() bool { return e.connected != nil && e.connected.Passcred() } -func (e *baseEndpoint) setPasscred(pc bool) { - if pc { - atomic.StoreInt32(&e.passcred, 1) - } else { - atomic.StoreInt32(&e.passcred, 0) - } -} - // Connected implements ConnectingEndpoint.Connected. func (e *baseEndpoint) Connected() bool { return e.receiver != nil && e.connected != nil @@ -865,9 +857,6 @@ func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { switch opt { - case tcpip.BroadcastOption: - case tcpip.PasscredOption: - e.setPasscred(v) case tcpip.ReuseAddressOption: default: log.Warningf("Unsupported socket option: %d", opt) @@ -890,9 +879,6 @@ func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: return false, nil - case tcpip.PasscredOption: - return e.Passcred(), nil - default: log.Warningf("Unsupported socket option: %d", opt) return false, tcpip.ErrUnknownProtocolOption @@ -980,6 +966,11 @@ func (*baseEndpoint) LastError() *tcpip.Error { return nil } +// SocketOptions implements Endpoint.SocketOptions. +func (e *baseEndpoint) SocketOptions() *tcpip.SocketOptions { + return &e.ops +} + // Shutdown closes the read and/or write end of the endpoint connection to its // peer. func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *syserr.Error { diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go index 309c183a3..88cd234d1 100644 --- a/pkg/sentry/syscalls/linux/sys_rlimit.go +++ b/pkg/sentry/syscalls/linux/sys_rlimit.go @@ -90,6 +90,9 @@ var setableLimits = map[limits.LimitType]struct{}{ limits.FileSize: {}, limits.MemoryLocked: {}, limits.Stack: {}, + // RSS can be set, but it's not enforced because Linux doesn't enforce it + // either: "This limit has effect only in Linux 2.4.x, x < 30" + limits.Rss: {}, // These are not enforced, but we include them here to avoid returning // EPERM, since some apps expect them to succeed. limits.Core: {}, diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 39ca9ea97..983f8d396 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -159,7 +159,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user defer wd.DecRef(t) } - // Load the new TaskContext. + // Load the new TaskImage. remainingTraversals := uint(linux.MaxSymlinkTraversals) loadArgs := loader.LoadArgs{ Opener: fsbridge.NewFSLookup(t.MountNamespace(), root, wd), @@ -173,12 +173,12 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user Features: t.Arch().FeatureSet(), } - tc, se := t.Kernel().LoadTaskImage(t, loadArgs) + image, se := t.Kernel().LoadTaskImage(t, loadArgs) if se != nil { return 0, nil, se.ToError() } - ctrl, err := t.Execve(tc) + ctrl, err := t.Execve(image) return 0, ctrl, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go index c8ce2aabc..7a409620d 100644 --- a/pkg/sentry/syscalls/linux/vfs2/execve.go +++ b/pkg/sentry/syscalls/linux/vfs2/execve.go @@ -109,7 +109,7 @@ func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr user executable = fsbridge.NewVFSFile(file) } - // Load the new TaskContext. + // Load the new TaskImage. mntns := t.MountNamespaceVFS2() wd := t.FSContext().WorkingDirectoryVFS2() defer wd.DecRef(t) @@ -126,11 +126,11 @@ func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr user Features: t.Arch().FeatureSet(), } - tc, se := t.Kernel().LoadTaskImage(t, loadArgs) + image, se := t.Kernel().LoadTaskImage(t, loadArgs) if se != nil { return 0, nil, se.ToError() } - ctrl, err := t.Execve(tc) + ctrl, err := t.Execve(image) return 0, ctrl, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index 9ce4f280a..8bb763a47 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -343,8 +343,8 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Copy data. var ( - n int64 - err error + total int64 + err error ) dw := dualWaiter{ inFile: inFile, @@ -357,13 +357,20 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // can block. nonBlock := outFile.StatusFlags()&linux.O_NONBLOCK != 0 if outIsPipe { - for n < count { - var spliceN int64 - spliceN, err = outPipeFD.SpliceFromNonPipe(t, inFile, offset, count) + for { + var n int64 + n, err = outPipeFD.SpliceFromNonPipe(t, inFile, offset, count-total) if offset != -1 { - offset += spliceN + offset += n + } + total += n + if total == count { + break + } + if err == nil && t.Interrupted() { + err = syserror.ErrInterrupted + break } - n += spliceN if err == syserror.ErrWouldBlock && !nonBlock { err = dw.waitForBoth(t) } @@ -374,7 +381,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } else { // Read inFile to buffer, then write the contents to outFile. buf := make([]byte, count) - for n < count { + for { var readN int64 if offset != -1 { readN, err = inFile.PRead(t, usermem.BytesIOSequence(buf), offset, vfs.ReadOptions{}) @@ -382,7 +389,6 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } else { readN, err = inFile.Read(t, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) } - n += readN // Write all of the bytes that we read. This may need // multiple write calls to complete. @@ -398,7 +404,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // We didn't complete the write. Only report the bytes that were actually // written, and rewind offsets as needed. notWritten := int64(len(wbuf)) - n -= notWritten + readN -= notWritten if offset == -1 { // We modified the offset of the input file itself during the read // operation. Rewind it. @@ -415,6 +421,16 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc break } } + + total += readN + buf = buf[readN:] + if total == count { + break + } + if err == nil && t.Interrupted() { + err = syserror.ErrInterrupted + break + } if err == syserror.ErrWouldBlock && !nonBlock { err = dw.waitForBoth(t) } @@ -432,7 +448,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } } - if n != 0 { + if total != 0 { inFile.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) @@ -445,7 +461,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // We can only pass a single file to handleIOError, so pick inFile arbitrarily. // This is used only for debugging purposes. - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "sendfile", inFile) + return uintptr(total), nil, slinux.HandleIOErrorVFS2(t, total != 0, err, syserror.ERESTARTSYS, "sendfile", inFile) } // dualWaiter is used to wait on one or both vfs.FileDescriptions. It is not diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index bbafb8b7f..1d1062aeb 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -336,11 +336,9 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo buf.WriteString(fmt.Sprintf("Sentry detected %d stuck task(s):\n", len(offenders))) for t, o := range offenders { tid := w.k.TaskSet().Root.IDOfTask(t) - buf.WriteString(fmt.Sprintf("\tTask tid: %v (%#x), entered RunSys state %v ago.\n", tid, uint64(tid), now.Sub(o.lastUpdateTime))) + buf.WriteString(fmt.Sprintf("\tTask tid: %v (goroutine %d), entered RunSys state %v ago.\n", tid, t.GoroutineID(), now.Sub(o.lastUpdateTime))) } - buf.WriteString("Search for '(*Task).run(0x..., 0x<tid>)' in the stack dump to find the offending goroutine") - // Force stack dump only if a new task is detected. w.doAction(w.TaskTimeoutAction, newTaskFound, &buf) } diff --git a/pkg/shim/runsc/runsc.go b/pkg/shim/runsc/runsc.go index e7c9640ba..aedaf5ee5 100644 --- a/pkg/shim/runsc/runsc.go +++ b/pkg/shim/runsc/runsc.go @@ -13,6 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package runsc provides an API to interact with runsc command line. package runsc import ( @@ -33,12 +34,32 @@ import ( specs "github.com/opencontainers/runtime-spec/specs-go" ) -// Monitor is the default process monitor to be used by runsc. -var Monitor runc.ProcessMonitor = runc.Monitor - // DefaultCommand is the default command for Runsc. const DefaultCommand = "runsc" +// Monitor is the default process monitor to be used by runsc. +var Monitor runc.ProcessMonitor = &LogMonitor{Next: runc.Monitor} + +// LogMonitor implements the runc.ProcessMonitor interface, logging the command +// that is getting executed, and then forwarding the call to another +// implementation. +type LogMonitor struct { + Next runc.ProcessMonitor +} + +// Start implements runc.ProcessMonitor. +func (l *LogMonitor) Start(cmd *exec.Cmd) (chan runc.Exit, error) { + log.L.Debugf("Executing: %s", cmd.Args) + return l.Next.Start(cmd) +} + +// Wait implements runc.ProcessMonitor. +func (l *LogMonitor) Wait(cmd *exec.Cmd, ch chan runc.Exit) (int, error) { + status, err := l.Next.Wait(cmd, ch) + log.L.Debugf("Command exit code: %d, err: %v", status, err) + return status, err +} + // Runsc is the client to the runsc cli. type Runsc struct { Command string @@ -370,9 +391,10 @@ func (r *Runsc) Stats(context context.Context, id string) (*runc.Stats, error) { }() var e runc.Event if err := json.NewDecoder(rd).Decode(&e); err != nil { + log.L.Debugf("Parsing events error: %v", err) return nil, err } - log.L.Debugf("Stats returned: %+v", e.Stats) + log.L.Debugf("Stats returned, type: %s, stats: %+v", e.Type, e.Stats) if e.Type != "stats" { return nil, fmt.Errorf(`unexpected event type %q, wanted "stats"`, e.Type) } diff --git a/pkg/shim/runsc/utils.go b/pkg/shim/runsc/utils.go index c514b3bc7..55f17d29e 100644 --- a/pkg/shim/runsc/utils.go +++ b/pkg/shim/runsc/utils.go @@ -36,9 +36,20 @@ func putBuf(b *bytes.Buffer) { bytesBufferPool.Put(b) } -// FormatLogPath parses runsc config, and fill in %ID% in the log path. -func FormatLogPath(id string, config map[string]string) { +// FormatRunscLogPath parses runsc config, and fill in %ID% in the log path. +func FormatRunscLogPath(id string, config map[string]string) { if path, ok := config["debug-log"]; ok { config["debug-log"] = strings.Replace(path, "%ID%", id, -1) } } + +// FormatShimLogPath creates the file path to the log file. It replaces %ID% +// in the path with the provided "id". It also uses a default log name if the +// path end with '/'. +func FormatShimLogPath(path string, id string) string { + if strings.HasSuffix(path, "/") { + // Default format: <path>/runsc-shim-<ID>.log + path += "runsc-shim-%ID%.log" + } + return strings.Replace(path, "%ID%", id, -1) +} diff --git a/pkg/shim/v1/proc/init.go b/pkg/shim/v1/proc/init.go index dab3123d6..9fd7d978c 100644 --- a/pkg/shim/v1/proc/init.go +++ b/pkg/shim/v1/proc/init.go @@ -397,7 +397,7 @@ func (p *Init) Exec(ctx context.Context, path string, r *ExecConfig) (process.Pr } // exec returns a new exec'd process. -func (p *Init) exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { +func (p *Init) exec(path string, r *ExecConfig) (process.Process, error) { // process exec request var spec specs.Process if err := json.Unmarshal(r.Spec.Value, &spec); err != nil { diff --git a/pkg/shim/v1/proc/init_state.go b/pkg/shim/v1/proc/init_state.go index 9233ecc85..0065fc385 100644 --- a/pkg/shim/v1/proc/init_state.go +++ b/pkg/shim/v1/proc/init_state.go @@ -95,7 +95,7 @@ func (s *createdState) SetExited(status int) { } func (s *createdState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { - return s.p.exec(ctx, path, r) + return s.p.exec(path, r) } type runningState struct { @@ -137,7 +137,7 @@ func (s *runningState) SetExited(status int) { } func (s *runningState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { - return s.p.exec(ctx, path, r) + return s.p.exec(path, r) } type stoppedState struct { diff --git a/pkg/shim/v1/proc/types.go b/pkg/shim/v1/proc/types.go index 2b0df4663..fc182cf5e 100644 --- a/pkg/shim/v1/proc/types.go +++ b/pkg/shim/v1/proc/types.go @@ -40,7 +40,6 @@ type CreateConfig struct { Stdin string Stdout string Stderr string - Options *types.Any } // ExecConfig holds exec creation configuration. diff --git a/pkg/shim/v1/proc/utils.go b/pkg/shim/v1/proc/utils.go index 716de2f59..7c2c409af 100644 --- a/pkg/shim/v1/proc/utils.go +++ b/pkg/shim/v1/proc/utils.go @@ -67,24 +67,6 @@ func getLastRuntimeError(r *runsc.Runsc) (string, error) { return errMsg, nil } -func copyFile(to, from string) error { - ff, err := os.Open(from) - if err != nil { - return err - } - defer ff.Close() - tt, err := os.Create(to) - if err != nil { - return err - } - defer tt.Close() - - p := bufPool.Get().(*[]byte) - defer bufPool.Put(p) - _, err = io.CopyBuffer(tt, ff, *p) - return err -} - func hasNoIO(r *CreateConfig) bool { return r.Stdin == "" && r.Stdout == "" && r.Stderr == "" } diff --git a/pkg/shim/v1/shim/service.go b/pkg/shim/v1/shim/service.go index 84a810cb2..80aa59b33 100644 --- a/pkg/shim/v1/shim/service.go +++ b/pkg/shim/v1/shim/service.go @@ -130,7 +130,6 @@ func (s *Service) Create(ctx context.Context, r *shim.CreateTaskRequest) (_ *shi Stdin: r.Stdin, Stdout: r.Stdout, Stderr: r.Stderr, - Options: r.Options, } defer func() { if err != nil { @@ -150,7 +149,6 @@ func (s *Service) Create(ctx context.Context, r *shim.CreateTaskRequest) (_ *shi } } process, err := newInit( - ctx, s.config.Path, s.config.WorkDir, s.config.RuntimeRoot, @@ -158,6 +156,7 @@ func (s *Service) Create(ctx context.Context, r *shim.CreateTaskRequest) (_ *shi s.config.RunscConfig, s.platform, config, + r.Options, ) if err := process.Create(ctx, config); err != nil { return nil, errdefs.ToGRPC(err) @@ -533,14 +532,14 @@ func getTopic(ctx context.Context, e interface{}) string { return runtime.TaskUnknownTopic } -func newInit(ctx context.Context, path, workDir, runtimeRoot, namespace string, config map[string]string, platform stdio.Platform, r *proc.CreateConfig) (*proc.Init, error) { - var options runctypes.CreateOptions - if r.Options != nil { - v, err := typeurl.UnmarshalAny(r.Options) +func newInit(path, workDir, runtimeRoot, namespace string, config map[string]string, platform stdio.Platform, r *proc.CreateConfig, options *types.Any) (*proc.Init, error) { + var opts runctypes.CreateOptions + if options != nil { + v, err := typeurl.UnmarshalAny(options) if err != nil { return nil, err } - options = *v.(*runctypes.CreateOptions) + opts = *v.(*runctypes.CreateOptions) } spec, err := utils.ReadSpec(r.Bundle) @@ -551,7 +550,7 @@ func newInit(ctx context.Context, path, workDir, runtimeRoot, namespace string, return nil, fmt.Errorf("update volume annotations: %w", err) } - runsc.FormatLogPath(r.ID, config) + runsc.FormatRunscLogPath(r.ID, config) rootfs := filepath.Join(path, "rootfs") runtime := proc.NewRunsc(runtimeRoot, path, namespace, r.Runtime, config) p := proc.New(r.ID, runtime, stdio.Stdio{ @@ -564,8 +563,8 @@ func newInit(ctx context.Context, path, workDir, runtimeRoot, namespace string, p.Platform = platform p.Rootfs = rootfs p.WorkDir = workDir - p.IoUID = int(options.IoUid) - p.IoGID = int(options.IoGid) + p.IoUID = int(opts.IoUid) + p.IoGID = int(opts.IoGid) p.Sandbox = utils.IsSandbox(spec) p.UserLog = utils.UserLogPath(spec) p.Monitor = reaper.Default diff --git a/pkg/shim/v2/BUILD b/pkg/shim/v2/BUILD index 7e0a114a0..f37fefddc 100644 --- a/pkg/shim/v2/BUILD +++ b/pkg/shim/v2/BUILD @@ -7,15 +7,17 @@ go_library( srcs = [ "api.go", "epoll.go", + "options.go", "service.go", "service_linux.go", + "state.go", ], visibility = ["//shim:__subpackages__"], deps = [ + "//pkg/cleanup", "//pkg/shim/runsc", "//pkg/shim/v1/proc", "//pkg/shim/v1/utils", - "//pkg/shim/v2/options", "//pkg/shim/v2/runtimeoptions", "//runsc/specutils", "@com_github_burntsushi_toml//:go_default_library", @@ -38,6 +40,7 @@ go_library( "@com_github_containerd_fifo//:go_default_library", "@com_github_containerd_typeurl//:go_default_library", "@com_github_gogo_protobuf//types:go_default_library", + "@com_github_sirupsen_logrus//:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/shim/v2/options.go b/pkg/shim/v2/options.go new file mode 100644 index 000000000..9db33fd1f --- /dev/null +++ b/pkg/shim/v2/options.go @@ -0,0 +1,50 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package v2 + +const optionsType = "io.containerd.runsc.v1.options" + +// options is runtime options for io.containerd.runsc.v1. +type options struct { + // ShimCgroup is the cgroup the shim should be in. + ShimCgroup string `toml:"shim_cgroup" json:"shimCgroup"` + + // IoUID is the I/O's pipes uid. + IoUID uint32 `toml:"io_uid" json:"ioUid"` + + // IoGID is the I/O's pipes gid. + IoGID uint32 `toml:"io_gid" json:"ioGid"` + + // BinaryName is the binary name of the runsc binary. + BinaryName string `toml:"binary_name" json:"binaryName"` + + // Root is the runsc root directory. + Root string `toml:"root" json:"root"` + + // LogLevel sets the logging level. Some of the possible values are: debug, + // info, warning. + // + // This configuration only applies when the shim is running as a service. + LogLevel string `toml:"log_level" json:"logLevel"` + + // LogPath is the path to log directory. %ID% tags inside the string are + // replaced with the container ID. + // + // This configuration only applies when the shim is running as a service. + LogPath string `toml:"log_path" json:"logPath"` + + // RunscConfig is a key/value map of all runsc flags. + RunscConfig map[string]string `toml:"runsc_config" json:"runscConfig"` +} diff --git a/pkg/shim/v2/options/BUILD b/pkg/shim/v2/options/BUILD deleted file mode 100644 index ca212e874..000000000 --- a/pkg/shim/v2/options/BUILD +++ /dev/null @@ -1,11 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "options", - srcs = [ - "options.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/shim/v2/options/options.go b/pkg/shim/v2/options/options.go deleted file mode 100644 index de09f2f79..000000000 --- a/pkg/shim/v2/options/options.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package options - -const OptionType = "io.containerd.runsc.v1.options" - -// Options is runtime options for io.containerd.runsc.v1. -type Options struct { - // ShimCgroup is the cgroup the shim should be in. - ShimCgroup string `toml:"shim_cgroup"` - // IoUid is the I/O's pipes uid. - IoUid uint32 `toml:"io_uid"` - // IoUid is the I/O's pipes gid. - IoGid uint32 `toml:"io_gid"` - // BinaryName is the binary name of the runsc binary. - BinaryName string `toml:"binary_name"` - // Root is the runsc root directory. - Root string `toml:"root"` - // RunscConfig is a key/value map of all runsc flags. - RunscConfig map[string]string `toml:"runsc_config"` -} diff --git a/pkg/shim/v2/service.go b/pkg/shim/v2/service.go index 1534152fc..2e39d2c4a 100644 --- a/pkg/shim/v2/service.go +++ b/pkg/shim/v2/service.go @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package v2 implements Containerd Shim v2 interface. package v2 import ( "context" "fmt" - "io/ioutil" + "io" "os" "os/exec" "path/filepath" @@ -43,12 +44,13 @@ import ( "github.com/containerd/containerd/sys/reaper" "github.com/containerd/typeurl" "github.com/gogo/protobuf/types" + "github.com/sirupsen/logrus" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/cleanup" "gvisor.dev/gvisor/pkg/shim/runsc" "gvisor.dev/gvisor/pkg/shim/v1/proc" "gvisor.dev/gvisor/pkg/shim/v1/utils" - "gvisor.dev/gvisor/pkg/shim/v2/options" "gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions" "gvisor.dev/gvisor/runsc/specutils" ) @@ -71,49 +73,88 @@ const configFile = "config.toml" // New returns a new shim service that can be used via GRPC. func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()) (shim.Shim, error) { + log.L.Debugf("service.New, id: %s", id) + + var opts shim.Opts + if ctxOpts := ctx.Value(shim.OptsKey{}); ctxOpts != nil { + opts = ctxOpts.(shim.Opts) + } + ep, err := newOOMEpoller(publisher) if err != nil { return nil, err } go ep.run(ctx) s := &service{ - id: id, - context: ctx, - processes: make(map[string]process.Process), - events: make(chan interface{}, 128), - ec: proc.ExitCh, - oomPoller: ep, - cancel: cancel, - } - go s.processExits() - runsc.Monitor = reaper.Default + id: id, + processes: make(map[string]process.Process), + events: make(chan interface{}, 128), + ec: proc.ExitCh, + oomPoller: ep, + cancel: cancel, + genericOptions: opts, + } + go s.processExits(ctx) + runsc.Monitor = &runsc.LogMonitor{Next: reaper.Default} if err := s.initPlatform(); err != nil { cancel() return nil, fmt.Errorf("failed to initialized platform behavior: %w", err) } - go s.forward(publisher) + go s.forward(ctx, publisher) return s, nil } -// service is the shim implementation of a remote shim over GRPC. +// service is the shim implementation of a remote shim over GRPC. It runs in 2 +// different modes: +// 1. Service: process runs for the life time of the container and receives +// calls described in shimapi.TaskService interface. +// 2. Tool: process is short lived and runs only to perform the requested +// operations and then exits. It implements the direct functions in +// shim.Shim interface. +// +// When the service is running, it saves a json file with state information so +// that commands sent to the tool can load the state and perform the operation. type service struct { mu sync.Mutex - context context.Context - task process.Process + // id is the container ID. + id string + + // bundle is a path provided by the caller on container creation. Store + // because it's needed in commands that don't receive bundle in the request. + bundle string + + // task is the main process that is running the container. + task *proc.Init + + // processes maps ExecId to processes running through exec. processes map[string]process.Process - events chan interface{} - platform stdio.Platform - opts options.Options - ec chan proc.Exit + + events chan interface{} + + // platform handles operations related to the console. + platform stdio.Platform + + // genericOptions are options that come from the shim interface and are common + // to all shims. + genericOptions shim.Opts + + // opts are configuration options specific for this shim. + opts options + + // ex gets notified whenever the container init process or an exec'd process + // exits from inside the sandbox. + ec chan proc.Exit + + // oomPoller monitors the sandbox's cgroup for OOM notifications. oomPoller *epoller - id string - bundle string + // cancel is a function that needs to be called before the shim stops. The + // function is provided by the caller to New(). cancel func() } -func newCommand(ctx context.Context, containerdBinary, containerdAddress string) (*exec.Cmd, error) { +func (s *service) newCommand(ctx context.Context, containerdBinary, containerdAddress string) (*exec.Cmd, error) { ns, err := namespaces.NamespaceRequired(ctx) if err != nil { return nil, err @@ -131,6 +172,9 @@ func newCommand(ctx context.Context, containerdBinary, containerdAddress string) "-address", containerdAddress, "-publish-binary", containerdBinary, } + if s.genericOptions.Debug { + args = append(args, "-debug") + } cmd := exec.Command(self, args...) cmd.Dir = cwd cmd.Env = append(os.Environ(), "GOMAXPROCS=2") @@ -141,7 +185,9 @@ func newCommand(ctx context.Context, containerdBinary, containerdAddress string) } func (s *service) StartShim(ctx context.Context, id, containerdBinary, containerdAddress, containerdTTRPCAddress string) (string, error) { - cmd, err := newCommand(ctx, containerdBinary, containerdAddress) + log.L.Debugf("StartShim, id: %s, binary: %q, address: %q", id, containerdBinary, containerdAddress) + + cmd, err := s.newCommand(ctx, containerdBinary, containerdAddress) if err != nil { return "", err } @@ -162,14 +208,15 @@ func (s *service) StartShim(ctx context.Context, id, containerdBinary, container cmd.ExtraFiles = append(cmd.ExtraFiles, f) + log.L.Debugf("Executing: %q %s", cmd.Path, cmd.Args) if err := cmd.Start(); err != nil { return "", err } - defer func() { - if err != nil { - cmd.Process.Kill() - } - }() + cu := cleanup.Make(func() { + cmd.Process.Kill() + }) + defer cu.Clean() + // make sure to wait after start go cmd.Wait() if err := shim.WritePidFile("shim.pid", cmd.Process.Pid); err != nil { @@ -181,10 +228,15 @@ func (s *service) StartShim(ctx context.Context, id, containerdBinary, container if err := shim.SetScore(cmd.Process.Pid); err != nil { return "", fmt.Errorf("failed to set OOM Score on shim: %w", err) } + cu.Release() return address, nil } +// Cleanup is called from another process (need to reload state) to stop the +// container and undo all operations done in Create(). func (s *service) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, error) { + log.L.Debugf("Cleanup") + path, err := os.Getwd() if err != nil { return nil, err @@ -193,18 +245,19 @@ func (s *service) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, error) if err != nil { return nil, err } - runtime, err := s.readRuntime(path) - if err != nil { + var st state + if err := st.load(path); err != nil { return nil, err } - r := proc.NewRunsc(s.opts.Root, path, ns, runtime, nil) + r := proc.NewRunsc(s.opts.Root, path, ns, st.Options.BinaryName, nil) + if err := r.Delete(ctx, s.id, &runsc.DeleteOpts{ Force: true, }); err != nil { - log.L.Printf("failed to remove runc container: %v", err) + log.L.Infof("failed to remove runc container: %v", err) } - if err := mount.UnmountAll(filepath.Join(path, "rootfs"), 0); err != nil { - log.L.Printf("failed to cleanup rootfs mount: %v", err) + if err := mount.UnmountAll(st.Rootfs, 0); err != nil { + log.L.Infof("failed to cleanup rootfs mount: %v", err) } return &taskAPI.DeleteResponse{ ExitedAt: time.Now(), @@ -212,31 +265,24 @@ func (s *service) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, error) }, nil } -func (s *service) readRuntime(path string) (string, error) { - data, err := ioutil.ReadFile(filepath.Join(path, "runtime")) - if err != nil { - return "", err - } - return string(data), nil -} - -func (s *service) writeRuntime(path, runtime string) error { - return ioutil.WriteFile(filepath.Join(path, "runtime"), []byte(runtime), 0600) -} - // Create creates a new initial process and container with the underlying OCI // runtime. -func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ *taskAPI.CreateTaskResponse, err error) { +func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (*taskAPI.CreateTaskResponse, error) { + log.L.Debugf("Create, id: %s, bundle: %q", r.ID, r.Bundle) + s.mu.Lock() defer s.mu.Unlock() + // Save the main task id and bundle to the shim for additional requests. + s.id = r.ID + s.bundle = r.Bundle + ns, err := namespaces.NamespaceRequired(ctx) if err != nil { return nil, fmt.Errorf("create namespace: %w", err) } // Read from root for now. - var opts options.Options if r.Options != nil { v, err := typeurl.UnmarshalAny(r.Options) if err != nil { @@ -245,16 +291,16 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ * var path string switch o := v.(type) { case *runctypes.CreateOptions: // containerd 1.2.x - opts.IoUid = o.IoUid - opts.IoGid = o.IoGid - opts.ShimCgroup = o.ShimCgroup + s.opts.IoUID = o.IoUid + s.opts.IoGID = o.IoGid + s.opts.ShimCgroup = o.ShimCgroup case *runctypes.RuncOptions: // containerd 1.2.x root := proc.RunscRoot if o.RuntimeRoot != "" { root = o.RuntimeRoot } - opts.BinaryName = o.Runtime + s.opts.BinaryName = o.Runtime path = filepath.Join(root, configFile) if _, err := os.Stat(path); err != nil { @@ -268,7 +314,7 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ * if o.ConfigPath == "" { break } - if o.TypeUrl != options.OptionType { + if o.TypeUrl != optionsType { return nil, fmt.Errorf("unsupported option type %q", o.TypeUrl) } path = o.ConfigPath @@ -276,12 +322,61 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ * return nil, fmt.Errorf("unsupported option type %q", r.Options.TypeUrl) } if path != "" { - if _, err = toml.DecodeFile(path, &opts); err != nil { + if _, err = toml.DecodeFile(path, &s.opts); err != nil { return nil, fmt.Errorf("decode config file %q: %w", path, err) } } } + if len(s.opts.LogLevel) != 0 { + lvl, err := logrus.ParseLevel(s.opts.LogLevel) + if err != nil { + return nil, err + } + logrus.SetLevel(lvl) + } + if len(s.opts.LogPath) != 0 { + logPath := runsc.FormatShimLogPath(s.opts.LogPath, s.id) + if err := os.MkdirAll(filepath.Dir(logPath), 0777); err != nil { + return nil, fmt.Errorf("failed to create log dir: %w", err) + } + logFile, err := os.Create(logPath) + if err != nil { + return nil, fmt.Errorf("failed to create log file: %w", err) + } + log.L.Debugf("Starting mirror log at %q", logPath) + std := logrus.StandardLogger() + std.SetOutput(io.MultiWriter(std.Out, logFile)) + + log.L.Debugf("Create shim") + log.L.Debugf("***************************") + log.L.Debugf("Args: %s", os.Args) + log.L.Debugf("PID: %d", os.Getpid()) + log.L.Debugf("ID: %s", s.id) + log.L.Debugf("Options: %+v", s.opts) + log.L.Debugf("Bundle: %s", r.Bundle) + log.L.Debugf("Terminal: %t", r.Terminal) + log.L.Debugf("stdin: %s", r.Stdin) + log.L.Debugf("stdout: %s", r.Stdout) + log.L.Debugf("stderr: %s", r.Stderr) + log.L.Debugf("***************************") + } + + // Save state before any action is taken to ensure Cleanup() will have all + // the information it needs to undo the operations. + st := state{ + Rootfs: filepath.Join(r.Bundle, "rootfs"), + Options: s.opts, + } + if err := st.save(r.Bundle); err != nil { + return nil, err + } + + if err := os.Mkdir(st.Rootfs, 0711); err != nil && !os.IsExist(err) { + return nil, err + } + + // Convert from types.Mount to proc.Mount. var mounts []proc.Mount for _, m := range r.Rootfs { mounts = append(mounts, proc.Mount{ @@ -292,62 +387,41 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ * }) } - rootfs := filepath.Join(r.Bundle, "rootfs") - if err := os.Mkdir(rootfs, 0711); err != nil && !os.IsExist(err) { - return nil, err + // Cleans up all mounts in case of failure. + cu := cleanup.Make(func() { + if err := mount.UnmountAll(st.Rootfs, 0); err != nil { + log.L.Infof("failed to cleanup rootfs mount: %v", err) + } + }) + defer cu.Clean() + for _, rm := range mounts { + m := &mount.Mount{ + Type: rm.Type, + Source: rm.Source, + Options: rm.Options, + } + if err := m.Mount(st.Rootfs); err != nil { + return nil, fmt.Errorf("failed to mount rootfs component %v: %w", m, err) + } } config := &proc.CreateConfig{ ID: r.ID, Bundle: r.Bundle, - Runtime: opts.BinaryName, + Runtime: s.opts.BinaryName, Rootfs: mounts, Terminal: r.Terminal, Stdin: r.Stdin, Stdout: r.Stdout, Stderr: r.Stderr, - Options: r.Options, - } - if err := s.writeRuntime(r.Bundle, opts.BinaryName); err != nil { - return nil, err - } - defer func() { - if err != nil { - if err := mount.UnmountAll(rootfs, 0); err != nil { - log.L.Printf("failed to cleanup rootfs mount: %v", err) - } - } - }() - for _, rm := range mounts { - m := &mount.Mount{ - Type: rm.Type, - Source: rm.Source, - Options: rm.Options, - } - if err := m.Mount(rootfs); err != nil { - return nil, fmt.Errorf("failed to mount rootfs component %v: %w", m, err) - } } - process, err := newInit( - ctx, - r.Bundle, - filepath.Join(r.Bundle, "work"), - ns, - s.platform, - config, - &opts, - rootfs, - ) + process, err := newInit(r.Bundle, filepath.Join(r.Bundle, "work"), ns, s.platform, config, &s.opts, st.Rootfs) if err != nil { return nil, errdefs.ToGRPC(err) } if err := process.Create(ctx, config); err != nil { return nil, errdefs.ToGRPC(err) } - // Save the main task id and bundle to the shim for additional - // requests. - s.id = r.ID - s.bundle = r.Bundle // Set up OOM notification on the sandbox's cgroup. This is done on // sandbox create since the sandbox process will be created here. @@ -361,16 +435,19 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ * return nil, fmt.Errorf("add cg to OOM monitor: %w", err) } } + + // Success + cu.Release() s.task = process - s.opts = opts return &taskAPI.CreateTaskResponse{ Pid: uint32(process.Pid()), }, nil - } // Start starts a process. func (s *service) Start(ctx context.Context, r *taskAPI.StartRequest) (*taskAPI.StartResponse, error) { + log.L.Debugf("Start, id: %s, execID: %s", r.ID, r.ExecID) + p, err := s.getProcess(r.ExecID) if err != nil { return nil, err @@ -387,6 +464,8 @@ func (s *service) Start(ctx context.Context, r *taskAPI.StartRequest) (*taskAPI. // Delete deletes the initial process and container. func (s *service) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*taskAPI.DeleteResponse, error) { + log.L.Debugf("Delete, id: %s, execID: %s", r.ID, r.ExecID) + p, err := s.getProcess(r.ExecID) if err != nil { return nil, err @@ -397,13 +476,11 @@ func (s *service) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*taskAP if err := p.Delete(ctx); err != nil { return nil, err } - isTask := r.ExecID == "" - if !isTask { + if len(r.ExecID) != 0 { s.mu.Lock() delete(s.processes, r.ExecID) s.mu.Unlock() - } - if isTask && s.platform != nil { + } else if s.platform != nil { s.platform.Close() } return &taskAPI.DeleteResponse{ @@ -415,17 +492,18 @@ func (s *service) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*taskAP // Exec spawns an additional process inside the container. func (s *service) Exec(ctx context.Context, r *taskAPI.ExecProcessRequest) (*types.Empty, error) { + log.L.Debugf("Exec, id: %s, execID: %s", r.ID, r.ExecID) + s.mu.Lock() p := s.processes[r.ExecID] s.mu.Unlock() if p != nil { return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ExecID) } - p = s.task - if p == nil { + if s.task == nil { return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") } - process, err := p.(*proc.Init).Exec(ctx, s.bundle, &proc.ExecConfig{ + process, err := s.task.Exec(ctx, s.bundle, &proc.ExecConfig{ ID: r.ExecID, Terminal: r.Terminal, Stdin: r.Stdin, @@ -444,6 +522,8 @@ func (s *service) Exec(ctx context.Context, r *taskAPI.ExecProcessRequest) (*typ // ResizePty resizes the terminal of a process. func (s *service) ResizePty(ctx context.Context, r *taskAPI.ResizePtyRequest) (*types.Empty, error) { + log.L.Debugf("ResizePty, id: %s, execID: %s, dimension: %dx%d", r.ID, r.ExecID, r.Height, r.Width) + p, err := s.getProcess(r.ExecID) if err != nil { return nil, err @@ -460,6 +540,8 @@ func (s *service) ResizePty(ctx context.Context, r *taskAPI.ResizePtyRequest) (* // State returns runtime state information for a process. func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI.StateResponse, error) { + log.L.Debugf("State, id: %s, execID: %s", r.ID, r.ExecID) + p, err := s.getProcess(r.ExecID) if err != nil { return nil, err @@ -494,16 +576,20 @@ func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI. // Pause the container. func (s *service) Pause(ctx context.Context, r *taskAPI.PauseRequest) (*types.Empty, error) { + log.L.Debugf("Pause, id: %s", r.ID) return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) } // Resume the container. func (s *service) Resume(ctx context.Context, r *taskAPI.ResumeRequest) (*types.Empty, error) { + log.L.Debugf("Resume, id: %s", r.ID) return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) } // Kill a process with the provided signal. func (s *service) Kill(ctx context.Context, r *taskAPI.KillRequest) (*types.Empty, error) { + log.L.Debugf("Kill, id: %s, execID: %s, signal: %d, all: %t", r.ID, r.ExecID, r.Signal, r.All) + p, err := s.getProcess(r.ExecID) if err != nil { return nil, err @@ -519,6 +605,8 @@ func (s *service) Kill(ctx context.Context, r *taskAPI.KillRequest) (*types.Empt // Pids returns all pids inside the container. func (s *service) Pids(ctx context.Context, r *taskAPI.PidsRequest) (*taskAPI.PidsResponse, error) { + log.L.Debugf("Pids, id: %s", r.ID) + pids, err := s.getContainerPids(ctx, r.ID) if err != nil { return nil, errdefs.ToGRPC(err) @@ -550,6 +638,8 @@ func (s *service) Pids(ctx context.Context, r *taskAPI.PidsRequest) (*taskAPI.Pi // CloseIO closes the I/O context of a process. func (s *service) CloseIO(ctx context.Context, r *taskAPI.CloseIORequest) (*types.Empty, error) { + log.L.Debugf("CloseIO, id: %s, execID: %s, stdin: %t", r.ID, r.ExecID, r.Stdin) + p, err := s.getProcess(r.ExecID) if err != nil { return nil, err @@ -564,11 +654,14 @@ func (s *service) CloseIO(ctx context.Context, r *taskAPI.CloseIORequest) (*type // Checkpoint checkpoints the container. func (s *service) Checkpoint(ctx context.Context, r *taskAPI.CheckpointTaskRequest) (*types.Empty, error) { + log.L.Debugf("Checkpoint, id: %s", r.ID) return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) } // Connect returns shim information such as the shim's pid. func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*taskAPI.ConnectResponse, error) { + log.L.Debugf("Connect, id: %s", r.ID) + var pid int if s.task != nil { pid = s.task.Pid() @@ -580,27 +673,21 @@ func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*task } func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*types.Empty, error) { + log.L.Debugf("Shutdown, id: %s", r.ID) s.cancel() os.Exit(0) return empty, nil } func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.StatsResponse, error) { - path, err := os.Getwd() - if err != nil { - return nil, err - } - ns, err := namespaces.NamespaceRequired(ctx) - if err != nil { - return nil, err - } - runtime, err := s.readRuntime(path) - if err != nil { - return nil, err + log.L.Debugf("Stats, id: %s", r.ID) + if s.task == nil { + log.L.Debugf("Stats error, id: %s: container not created", r.ID) + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") } - rs := proc.NewRunsc(s.opts.Root, path, ns, runtime, nil) - stats, err := rs.Stats(ctx, s.id) + stats, err := s.task.Runtime().Stats(ctx, s.id) if err != nil { + log.L.Debugf("Stats error, id: %s: %v", r.ID, err) return nil, err } @@ -611,7 +698,7 @@ func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI. // as runc. // // [0]: https://github.com/google/gvisor/blob/277a0d5a1fbe8272d4729c01ee4c6e374d047ebc/runsc/boot/events.go#L61-L81 - data, err := typeurl.MarshalAny(&cgroups.Metrics{ + metrics := &cgroups.Metrics{ CPU: &cgroups.CPUStat{ Usage: &cgroups.CPUUsage{ Total: stats.Cpu.Usage.Total, @@ -656,10 +743,13 @@ func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI. Current: stats.Pids.Current, Limit: stats.Pids.Limit, }, - }) + } + data, err := typeurl.MarshalAny(metrics) if err != nil { + log.L.Debugf("Stats error, id: %s: %v", r.ID, err) return nil, err } + log.L.Debugf("Stats success, id: %s: %+v", r.ID, data) return &taskAPI.StatsResponse{ Stats: data, }, nil @@ -672,6 +762,8 @@ func (s *service) Update(ctx context.Context, r *taskAPI.UpdateTaskRequest) (*ty // Wait waits for a process to exit. func (s *service) Wait(ctx context.Context, r *taskAPI.WaitRequest) (*taskAPI.WaitResponse, error) { + log.L.Debugf("Wait, id: %s, execID: %s", r.ID, r.ExecID) + p, err := s.getProcess(r.ExecID) if err != nil { return nil, err @@ -687,21 +779,22 @@ func (s *service) Wait(ctx context.Context, r *taskAPI.WaitRequest) (*taskAPI.Wa }, nil } -func (s *service) processExits() { +func (s *service) processExits(ctx context.Context) { for e := range s.ec { - s.checkProcesses(e) + s.checkProcesses(ctx, e) } } -func (s *service) checkProcesses(e proc.Exit) { +func (s *service) checkProcesses(ctx context.Context, e proc.Exit) { // TODO(random-liu): Add `shouldKillAll` logic if container pid // namespace is supported. for _, p := range s.allProcesses() { if p.ID() == e.ID { if ip, ok := p.(*proc.Init); ok { // Ensure all children are killed. - if err := ip.KillAll(s.context); err != nil { - log.G(s.context).WithError(err).WithField("id", ip.ID()). + log.L.Debugf("Container init process exited, killing all container processes") + if err := ip.KillAll(ctx); err != nil { + log.G(ctx).WithError(err).WithField("id", ip.ID()). Error("failed to kill init's children") } } @@ -737,7 +830,7 @@ func (s *service) getContainerPids(ctx context.Context, id string) ([]uint32, er if p == nil { return nil, fmt.Errorf("container must be created: %w", errdefs.ErrFailedPrecondition) } - ps, err := p.(*proc.Init).Runtime().Ps(ctx, id) + ps, err := p.Runtime().Ps(ctx, id) if err != nil { return nil, err } @@ -748,9 +841,9 @@ func (s *service) getContainerPids(ctx context.Context, id string) ([]uint32, er return pids, nil } -func (s *service) forward(publisher shim.Publisher) { +func (s *service) forward(ctx context.Context, publisher shim.Publisher) { for e := range s.events { - ctx, cancel := context.WithTimeout(s.context, 5*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) err := publisher.Publish(ctx, getTopic(e), e) cancel() if err != nil { @@ -790,12 +883,12 @@ func getTopic(e interface{}) string { case *events.TaskExecStarted: return runtime.TaskExecStartedEventTopic default: - log.L.Printf("no topic for type %#v", e) + log.L.Infof("no topic for type %#v", e) } return runtime.TaskUnknownTopic } -func newInit(ctx context.Context, path, workDir, namespace string, platform stdio.Platform, r *proc.CreateConfig, options *options.Options, rootfs string) (*proc.Init, error) { +func newInit(path, workDir, namespace string, platform stdio.Platform, r *proc.CreateConfig, options *options, rootfs string) (*proc.Init, error) { spec, err := utils.ReadSpec(r.Bundle) if err != nil { return nil, fmt.Errorf("read oci spec: %w", err) @@ -803,7 +896,7 @@ func newInit(ctx context.Context, path, workDir, namespace string, platform stdi if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil { return nil, fmt.Errorf("update volume annotations: %w", err) } - runsc.FormatLogPath(r.ID, options.RunscConfig) + runsc.FormatRunscLogPath(r.ID, options.RunscConfig) runtime := proc.NewRunsc(options.Root, path, namespace, options.BinaryName, options.RunscConfig) p := proc.New(r.ID, runtime, stdio.Stdio{ Stdin: r.Stdin, @@ -815,8 +908,8 @@ func newInit(ctx context.Context, path, workDir, namespace string, platform stdi p.Platform = platform p.Rootfs = rootfs p.WorkDir = workDir - p.IoUID = int(options.IoUid) - p.IoGID = int(options.IoGid) + p.IoUID = int(options.IoUID) + p.IoGID = int(options.IoGID) p.Sandbox = specutils.SpecContainerType(spec) == specutils.ContainerTypeSandbox p.UserLog = utils.UserLogPath(spec) p.Monitor = reaper.Default diff --git a/pkg/shim/v2/state.go b/pkg/shim/v2/state.go new file mode 100644 index 000000000..1f4be33d3 --- /dev/null +++ b/pkg/shim/v2/state.go @@ -0,0 +1,48 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package v2 + +import ( + "encoding/json" + "io/ioutil" + "path/filepath" +) + +const filename = "state.json" + +// state holds information needed between shim invocations. +type state struct { + // Rootfs is the full path to the location rootfs was mounted. + Rootfs string `json:"rootfs"` + + // Options is the configuration loaded from config.toml. + Options options `json:"options"` +} + +func (s state) load(path string) error { + data, err := ioutil.ReadFile(filepath.Join(path, filename)) + if err != nil { + return err + } + return json.Unmarshal(data, &s) +} + +func (s state) save(path string) error { + data, err := json.Marshal(&s) + if err != nil { + return err + } + return ioutil.WriteFile(filepath.Join(path, filename), data, 0644) +} diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 454e07662..27f96a3ac 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -5,6 +5,7 @@ package(licenses = ["notice"]) go_library( name = "tcpip", srcs = [ + "socketops.go", "tcpip.go", "time_unsafe.go", "timer.go", diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 8868cf4e3..81f762e10 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -904,6 +904,12 @@ func ICMPv4Payload(want []byte) TransportChecker { t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) } payload := icmpv4.Payload() + + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(want) == 0 && len(payload) == 0 { + return + } + if diff := cmp.Diff(want, payload); diff != "" { t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) } @@ -994,6 +1000,12 @@ func ICMPv6Payload(want []byte) TransportChecker { t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) } payload := icmpv6.Payload() + + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(want) == 0 && len(payload) == 0 { + return + } + if diff := cmp.Diff(want, payload); diff != "" { t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) } diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 7e32b31b4..91fe7b6a5 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -89,7 +89,17 @@ type IPv4Fields struct { // DstAddr is the "destination ip address" of an IPv4 packet. DstAddr tcpip.Address - // Options is between 0 and 40 bytes or nil if empty. + // Options must be 40 bytes or less as they must fit along with the + // rest of the IPv4 header into the maximum size describable in the + // IHL field. RFC 791 section 3.1 says: + // IHL: 4 bits + // + // Internet Header Length is the length of the internet header in 32 + // bit words, and thus points to the beginning of the data. Note that + // the minimum value for a correct header is 5. + // + // That leaves ten 32 bit (4 byte) fields for options. An attempt to encode + // more will fail. Options IPv4Options } @@ -275,22 +285,19 @@ func (b IPv4) DestinationAddress() tcpip.Address { // IPv4Options is a buffer that holds all the raw IP options. type IPv4Options []byte -// AllocationSize implements stack.NetOptions. +// SizeWithPadding implements stack.NetOptions. // It reports the size to allocate for the Options. RFC 791 page 23 (end of // section 3.1) says of the padding at the end of the options: // The internet header padding is used to ensure that the internet // header ends on a 32 bit boundary. -func (o IPv4Options) AllocationSize() int { +func (o IPv4Options) SizeWithPadding() int { return (len(o) + IPv4IHLStride - 1) & ^(IPv4IHLStride - 1) } -// Options returns a buffer holding the options or nil. +// Options returns a buffer holding the options. func (b IPv4) Options() IPv4Options { hdrLen := b.HeaderLength() - if hdrLen > IPv4MinimumSize { - return IPv4Options(b[options:hdrLen:hdrLen]) - } - return nil + return IPv4Options(b[options:hdrLen:hdrLen]) } // TransportProtocol implements Network.TransportProtocol. @@ -368,14 +375,20 @@ func (b IPv4) Encode(i *IPv4Fields) { // worth a bit of optimisation here to keep the copy out of the fast path. hdrLen := IPv4MinimumSize if len(i.Options) != 0 { - // AllocationSize is always >= len(i.Options). - aLen := i.Options.AllocationSize() + // SizeWithPadding is always >= len(i.Options). + aLen := i.Options.SizeWithPadding() hdrLen += aLen if hdrLen > len(b) { panic(fmt.Sprintf("encode received %d bytes, wanted >= %d", len(b), hdrLen)) } - if aLen != copy(b[options:], i.Options) { - _ = copy(b[options+len(i.Options):options+aLen], []byte{0, 0, 0, 0}) + opts := b[options:] + // This avoids bounds checks on the next line(s) which would happen even + // if there's no work to do. + if n := copy(opts, i.Options); n != aLen { + padding := opts[n:][:aLen-n] + for i := range padding { + padding[i] = 0 + } } } b.SetHeaderLength(uint8(hdrLen)) diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 4e7e5f76a..55d09355a 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -54,7 +54,7 @@ type IPv6Fields struct { // NextHeader is the "next header" field of an IPv6 packet. NextHeader uint8 - // HopLimit is the "hop limit" field of an IPv6 packet. + // HopLimit is the "Hop Limit" field of an IPv6 packet. HopLimit uint8 // SrcAddr is the "source ip address" of an IPv6 packet. @@ -171,7 +171,7 @@ func (b IPv6) PayloadLength() uint16 { return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:]) } -// HopLimit returns the value of the "hop limit" field of the ipv6 header. +// HopLimit returns the value of the "Hop Limit" field of the ipv6 header. func (b IPv6) HopLimit() uint8 { return b[hopLimit] } @@ -236,6 +236,11 @@ func (b IPv6) SetDestinationAddress(addr tcpip.Address) { copy(b[v6DstAddr:][:IPv6AddressSize], addr) } +// SetHopLimit sets the value of the "Hop Limit" field. +func (b IPv6) SetHopLimit(v uint8) { + b[hopLimit] = v +} + // SetNextHeader sets the value of the "next header" field of the ipv6 header. func (b IPv6) SetNextHeader(v uint8) { b[IPv6NextHeaderOffset] = v diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go index 5ca75c834..2042f214a 100644 --- a/pkg/tcpip/header/parse/parse.go +++ b/pkg/tcpip/header/parse/parse.go @@ -109,6 +109,9 @@ traverseExtensions: fragOffset = extHdr.FragmentOffset() fragMore = extHdr.More() } + rawPayload := it.AsRawHeader(true /* consume */) + extensionsSize = dataClone.Size() - rawPayload.Buf.Size() + break traverseExtensions case header.IPv6RawPayloadHeader: // We've found the payload after any extensions. diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index 39ca774ef..973f06cbc 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -9,7 +9,6 @@ go_library( deps = [ "//pkg/sync", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index c95aef63c..a7f5f4979 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -271,21 +270,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n, nil } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - p := PacketInfo{ - Pkt: stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }), - Proto: 0, - GSO: nil, - } - - e.q.Write(p) - - return nil -} - // Wait implements stack.LinkEndpoint.Wait. func (*Endpoint) Wait() {} diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 975309fc8..fc620c7d5 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -558,11 +558,6 @@ func viewsEqual(vs1, vs2 []buffer.View) bool { return len(vs1) == len(vs2) && (len(vs1) == 0 || &vs1[0] == &vs2[0]) } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - return rawfile.NonBlockingWrite(e.fds[0], vv.ToView()) -} - // InjectOutobund implements stack.InjectableEndpoint.InjectOutbound. func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { return rawfile.NonBlockingWrite(e.fds[0], packet) diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 38aa694e4..edca57e4e 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -96,23 +96,6 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList panic("not implemented") } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - // There should be an ethernet header at the beginning of vv. - hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) - if !ok { - // Reject the packet if it's shorter than an ethernet header. - return tcpip.ErrBadAddress - } - linkHeader := header.Ethernet(hdr) - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), pkt) - - return nil -} - // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (*endpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareLoopback diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index e7493e5c5..cbda59775 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -8,7 +8,6 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 56a611825..22e79ce3a 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -17,7 +17,6 @@ package muxed import ( "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -106,13 +105,6 @@ func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protoco return tcpip.ErrNoRoute } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - // WriteRawPacket doesn't get a route or network address, so there's - // nowhere to write this. - return tcpip.ErrNoRoute -} - // InjectOutbound writes outbound packets to the appropriate // LinkInjectableEndpoint based on the dest address. func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD index 2cdb23475..00b42b924 100644 --- a/pkg/tcpip/link/nested/BUILD +++ b/pkg/tcpip/link/nested/BUILD @@ -11,7 +11,6 @@ go_library( deps = [ "//pkg/sync", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index d40de54df..0ee54c3d5 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -19,7 +19,6 @@ package nested import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -123,11 +122,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return e.child.WritePackets(r, gso, pkts, protocol) } -// WriteRawPacket implements stack.LinkEndpoint. -func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - return e.child.WriteRawPacket(vv) -} - // Wait implements stack.LinkEndpoint. func (e *Endpoint) Wait() { e.child.Wait() diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index 523b0d24b..71fcb73e1 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -67,11 +67,6 @@ func (*Endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, panic("not implemented") } -// WriteRawPacket implements stack.LinkEndpoint. -func (*Endpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - panic("not implemented") -} - // Attach implements stack.LinkEndpoint. func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD index 1d0079bd6..5bea598eb 100644 --- a/pkg/tcpip/link/qdisc/fifo/BUILD +++ b/pkg/tcpip/link/qdisc/fifo/BUILD @@ -13,7 +13,6 @@ go_library( "//pkg/sleep", "//pkg/sync", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index fc1e34fc7..9b41d60d5 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -197,13 +196,6 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB return enqueued, nil } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // TODO(gvisor.dev/issue/3267): Queue these packets as well once - // WriteRawPacket takes PacketBuffer instead of VectorisedView. - return e.lower.WriteRawPacket(vv) -} - // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 7fb8a6c49..a1e7018c8 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -224,21 +224,6 @@ func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketB panic("not implemented") } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - views := vv.Views() - // Transmit the packet. - e.mu.Lock() - ok := e.tx.transmit(views...) - e.mu.Unlock() - - if !ok { - return tcpip.ErrWouldBlock - } - - return nil -} - // dispatchLoop reads packets from the rx queue in a loop and dispatches them // to the network stack. func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index b3e8c4b92..8d9a91020 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -53,16 +53,35 @@ type endpoint struct { nested.Endpoint writer io.Writer maxPCAPLen uint32 + logPrefix string } var _ stack.GSOEndpoint = (*endpoint)(nil) var _ stack.LinkEndpoint = (*endpoint)(nil) var _ stack.NetworkDispatcher = (*endpoint)(nil) +type direction int + +const ( + directionSend = iota + directionRecv +) + // New creates a new sniffer link-layer endpoint. It wraps around another // endpoint and logs packets and they traverse the endpoint. func New(lower stack.LinkEndpoint) stack.LinkEndpoint { - sniffer := &endpoint{} + return NewWithPrefix(lower, "") +} + +// NewWithPrefix creates a new sniffer link-layer endpoint. It wraps around +// another endpoint and logs packets prefixed with logPrefix as they traverse +// the endpoint. +// +// logPrefix is prepended to the log line without any separators. +// E.g. logPrefix = "NIC:en0/" will produce log lines like +// "NIC:en0/send udp [...]". +func NewWithPrefix(lower stack.LinkEndpoint, logPrefix string) stack.LinkEndpoint { + sniffer := &endpoint{logPrefix: logPrefix} sniffer.Endpoint.Init(lower, sniffer) return sniffer } @@ -120,7 +139,7 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) ( // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - e.dumpPacket("recv", nil, protocol, pkt) + e.dumpPacket(directionRecv, nil, protocol, pkt) e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt) } @@ -129,10 +148,10 @@ func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protoc e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt) } -func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - logPacket(prefix, protocol, pkt, gso) + logPacket(e.logPrefix, dir, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Size() @@ -169,7 +188,7 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - e.dumpPacket("send", gso, protocol, pkt) + e.dumpPacket(directionSend, gso, protocol, pkt) return e.Endpoint.WritePacket(r, gso, protocol, pkt) } @@ -178,20 +197,12 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // forwards the request to the lower endpoint. func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.dumpPacket("send", gso, protocol, pkt) + e.dumpPacket(directionSend, gso, protocol, pkt) } return e.Endpoint.WritePackets(r, gso, pkts, protocol) } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - e.dumpPacket("send", nil, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - return e.Endpoint.WriteRawPacket(vv) -} - -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { +func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -201,6 +212,16 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P var fragmentOffset uint16 var moreFragments bool + var directionPrefix string + switch dir { + case directionSend: + directionPrefix = "send" + case directionRecv: + directionPrefix = "recv" + default: + panic(fmt.Sprintf("unrecognized direction: %d", dir)) + } + // Clone the packet buffer to not modify the original. // // We don't clone the original packet buffer so that the new packet buffer @@ -248,15 +269,16 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P arp := header.ARP(pkt.NetworkHeader().View()) log.Infof( - "%s arp %s (%s) -> %s (%s) valid:%t", + "%s%s arp %s (%s) -> %s (%s) valid:%t", prefix, + directionPrefix, tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()), tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()), arp.IsValid(), ) return default: - log.Infof("%s unknown network protocol", prefix) + log.Infof("%s%s unknown network protocol", prefix, directionPrefix) return } @@ -300,7 +322,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P icmpType = "info reply" } } - log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) + log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, directionPrefix, transName, src, dst, icmpType, size, id, icmp.Code()) return case header.ICMPv6ProtocolNumber: @@ -335,7 +357,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.ICMPv6RedirectMsg: icmpType = "redirect message" } - log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) + log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, directionPrefix, transName, src, dst, icmpType, size, id, icmp.Code()) return case header.UDPProtocolNumber: @@ -391,7 +413,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P } default: - log.Infof("%s %s -> %s unknown transport protocol: %d", prefix, src, dst, transProto) + log.Infof("%s%s %s -> %s unknown transport protocol: %d", prefix, directionPrefix, src, dst, transProto) return } @@ -399,5 +421,5 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P details += fmt.Sprintf(" gso: %+v", gso) } - log.Infof("%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details) + log.Infof("%s%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, directionPrefix, transName, src, srcPort, dst, dstPort, size, id, details) } diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 4c14f55d3..9a76bdba7 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -76,29 +76,13 @@ func (d *Device) Release(ctx context.Context) { } } -// NICID returns the NIC ID of the device. -// -// Must only be called after the device has been attached to an endpoint. -func (d *Device) NICID() tcpip.NICID { - d.mu.RLock() - defer d.mu.RUnlock() - - if d.endpoint == nil { - panic("called NICID on a device that has not been attached") - } - - return d.endpoint.nicID -} - // SetIff services TUNSETIFF ioctl(2) request. -// -// Returns true if a new NIC was created; false if an existing one was attached. -func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) { +func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { d.mu.Lock() defer d.mu.Unlock() if d.endpoint != nil { - return false, syserror.EINVAL + return syserror.EINVAL } // Input validations. @@ -106,7 +90,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) isTap := flags&linux.IFF_TAP != 0 supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI) if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 { - return false, syserror.EINVAL + return syserror.EINVAL } prefix := "tun" @@ -119,18 +103,18 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) linkCaps |= stack.CapabilityResolutionRequired } - endpoint, created, err := attachOrCreateNIC(s, name, prefix, linkCaps) + endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps) if err != nil { - return false, syserror.EINVAL + return syserror.EINVAL } d.endpoint = endpoint d.notifyHandle = d.endpoint.AddNotify(d) d.flags = flags - return created, nil + return nil } -func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, bool, error) { +func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) { for { // 1. Try to attach to an existing NIC. if name != "" { @@ -138,13 +122,13 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE endpoint, ok := linkEP.(*tunEndpoint) if !ok { // Not a NIC created by tun device. - return nil, false, syserror.EOPNOTSUPP + return nil, syserror.EOPNOTSUPP } if !endpoint.TryIncRef() { // Race detected: NIC got deleted in between. continue } - return endpoint, false, nil + return endpoint, nil } } @@ -167,12 +151,12 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE }) switch err { case nil: - return endpoint, true, nil + return endpoint, nil case tcpip.ErrDuplicateNICID: // Race detected: A NIC has been created in between. continue default: - return nil, false, syserror.EINVAL + return nil, syserror.EINVAL } } } diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index ee84c3d96..9b4602c1b 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -11,7 +11,6 @@ go_library( deps = [ "//pkg/gate", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], @@ -25,7 +24,6 @@ go_test( library = ":waitable", deps = [ "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index b152a0f26..cf0077f43 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -24,7 +24,6 @@ package waitable import ( "gvisor.dev/gvisor/pkg/gate" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -132,17 +131,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n, err } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - if !e.writeGate.Enter() { - return nil - } - - err := e.lower.WriteRawPacket(vv) - e.writeGate.Leave() - return err -} - // WaitWrite prevents new calls to WritePacket from reaching the lower endpoint, // and waits for inflight ones to finish before returning. func (e *Endpoint) WaitWrite() { diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 94827fc56..cf7fb5126 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -18,7 +18,6 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -81,11 +80,6 @@ func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack. return pkts.Len(), nil } -func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - e.writeCount++ - return nil -} - // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType { panic("unimplemented") diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index c118a2929..b38aff0b8 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -14,6 +14,7 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 33a4a0720..3d5c0d270 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -31,17 +31,15 @@ import ( const ( // ProtocolNumber is the ARP protocol number. ProtocolNumber = header.ARPProtocolNumber - - // ProtocolAddress is the address expected by the ARP endpoint. - ProtocolAddress = tcpip.Address("arp") ) -var _ stack.AddressableEndpoint = (*endpoint)(nil) +// ARP endpoints need to implement stack.NetworkEndpoint because the stack +// considers the layer above the link-layer a network layer; the only +// facility provided by the stack to deliver packets to a layer above +// the link-layer is via stack.NetworkEndpoint.HandlePacket. var _ stack.NetworkEndpoint = (*endpoint)(nil) type endpoint struct { - stack.AddressableEndpointState - protocol *protocol // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. @@ -87,7 +85,7 @@ func (e *endpoint) Disable() { } // DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint. -func (e *endpoint) DefaultTTL() uint8 { +func (*endpoint) DefaultTTL() uint8 { return 0 } @@ -100,25 +98,23 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.ARPSize } -func (e *endpoint) Close() { - e.AddressableEndpointState.Cleanup() -} +func (*endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { +func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. -func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { +func (*endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return ProtocolNumber } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { +func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -216,9 +212,8 @@ func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber func (p *protocol) MinimumPacketSize() int { return header.ARPSize } func (p *protocol) DefaultPrefixLen() int { return 0 } -func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - h := header.ARP(v) - return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress +func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) { + return "", "" } func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { @@ -228,7 +223,6 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L linkAddrCache: linkAddrCache, nud: nud, } - e.AddressableEndpointState.Init(e) return e } @@ -311,10 +305,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu } // NewProtocol returns an ARP network protocol. -// -// Note, to make sure that the ARP endpoint receives ARP packets, the "arp" -// address must be added to every NIC that should respond to ARP requests. See -// ProtocolAddress for more details. func NewProtocol(s *stack.Stack) stack.NetworkProtocol { return &protocol{stack: s} } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 087ee9c66..f462524c9 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -200,9 +200,6 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext { t.Fatalf("AddAddress for ipv4 failed: %v", err) } } - if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("AddAddress for arp failed: %v", err) - } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, @@ -439,6 +436,10 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) Promiscuous() bool { + return false +} + func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { r := stack.Route{ NetProto: protocol, diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index 47fb63290..d8e4a3b54 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -47,6 +47,7 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/faketime", "//pkg/tcpip/network/testutil", + "//pkg/tcpip/stack", "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 936601287..c75ca7d71 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -71,16 +71,25 @@ type FragmentID struct { // Fragmentation is the main structure that other modules // of the stack should use to implement IP Fragmentation. type Fragmentation struct { - mu sync.Mutex - highLimit int - lowLimit int - reassemblers map[FragmentID]*reassembler - rList reassemblerList - size int - timeout time.Duration - blockSize uint16 - clock tcpip.Clock - releaseJob *tcpip.Job + mu sync.Mutex + highLimit int + lowLimit int + reassemblers map[FragmentID]*reassembler + rList reassemblerList + size int + timeout time.Duration + blockSize uint16 + clock tcpip.Clock + releaseJob *tcpip.Job + timeoutHandler TimeoutHandler +} + +// TimeoutHandler is consulted if a packet reassembly has timed out. +type TimeoutHandler interface { + // OnReassemblyTimeout will be called with the first fragment (or nil, if the + // first fragment has not been received) of a packet whose reassembly has + // timed out. + OnReassemblyTimeout(pkt *stack.PacketBuffer) } // NewFragmentation creates a new Fragmentation. @@ -97,7 +106,7 @@ type Fragmentation struct { // reassemblingTimeout specifies the maximum time allowed to reassemble a packet. // Fragments are lazily evicted only when a new a packet with an // already existing fragmentation-id arrives after the timeout. -func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock) *Fragmentation { +func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock, timeoutHandler TimeoutHandler) *Fragmentation { if lowMemoryLimit >= highMemoryLimit { lowMemoryLimit = highMemoryLimit } @@ -111,12 +120,13 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea } f := &Fragmentation{ - reassemblers: make(map[FragmentID]*reassembler), - highLimit: highMemoryLimit, - lowLimit: lowMemoryLimit, - timeout: reassemblingTimeout, - blockSize: blockSize, - clock: clock, + reassemblers: make(map[FragmentID]*reassembler), + highLimit: highMemoryLimit, + lowLimit: lowMemoryLimit, + timeout: reassemblingTimeout, + blockSize: blockSize, + clock: clock, + timeoutHandler: timeoutHandler, } f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked) @@ -136,16 +146,8 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea // proto is the protocol number marked in the fragment being processed. It has // to be given here outside of the FragmentID struct because IPv6 should not use // the protocol to identify a fragment. -// -// releaseCB is a callback that will run when the fragment reassembly of a -// packet is complete or cancelled. releaseCB take a a boolean argument which is -// true iff the reassembly is cancelled due to timeout. releaseCB should be -// passed only with the first fragment of a packet. If more than one releaseCB -// are passed for the same packet, only the first releaseCB will be saved for -// the packet and the succeeding ones will be dropped by running them -// immediately with a false argument. func (f *Fragmentation) Process( - id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView, releaseCB func(bool)) ( + id FragmentID, first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) ( buffer.VectorisedView, uint8, bool, error) { if first > last { return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) @@ -160,10 +162,9 @@ func (f *Fragmentation) Process( return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) } - if l := vv.Size(); l < int(fragmentSize) { - return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) + if l := pkt.Data.Size(); l != int(fragmentSize) { + return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) } - vv.CapLength(int(fragmentSize)) f.mu.Lock() r, ok := f.reassemblers[id] @@ -179,15 +180,9 @@ func (f *Fragmentation) Process( f.releaseReassemblersLocked() } } - if releaseCB != nil { - if !r.setCallback(releaseCB) { - // We got a duplicate callback. Release it immediately. - releaseCB(false /* timedOut */) - } - } f.mu.Unlock() - res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv) + res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, pkt) if err != nil { // We probably got an invalid sequence of fragments. Just // discard the reassembler and move on. @@ -231,7 +226,9 @@ func (f *Fragmentation) release(r *reassembler, timedOut bool) { f.size = 0 } - r.release(timedOut) // releaseCB may run. + if h := f.timeoutHandler; timedOut && h != nil { + h.OnReassemblyTimeout(r.pkt) + } } // releaseReassemblersLocked releases already-expired reassemblers, then diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 5dcd10730..3a79688a8 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/network/testutil" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) // reassembleTimeout is dummy timeout used for testing, where the clock never @@ -40,13 +41,19 @@ func vv(size int, pieces ...string) buffer.VectorisedView { return buffer.NewVectorisedView(size, views) } +func pkt(size int, pieces ...string) *stack.PacketBuffer { + return stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv(size, pieces...), + }) +} + type processInput struct { id FragmentID first uint16 last uint16 more bool proto uint8 - vv buffer.VectorisedView + pkt *stack.PacketBuffer } type processOutput struct { @@ -63,8 +70,8 @@ var processTestCases = []struct { { comment: "One ID", in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")}, + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, }, out: []processOutput{ {vv: buffer.VectorisedView{}, done: false}, @@ -74,8 +81,8 @@ var processTestCases = []struct { { comment: "Next Header protocol mismatch", in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, vv: vv(2, "01")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, vv: vv(2, "23")}, + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, pkt: pkt(2, "01")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, pkt: pkt(2, "23")}, }, out: []processOutput{ {vv: buffer.VectorisedView{}, done: false}, @@ -85,10 +92,10 @@ var processTestCases = []struct { { comment: "Two IDs", in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, vv: vv(2, "ab")}, - {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, vv: vv(2, "cd")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")}, + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, + {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, pkt: pkt(2, "ab")}, + {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, pkt: pkt(2, "cd")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, }, out: []processOutput{ {vv: buffer.VectorisedView{}, done: false}, @@ -102,17 +109,17 @@ var processTestCases = []struct { func TestFragmentationProcess(t *testing.T) { for _, c := range processTestCases { t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil) firstFragmentProto := c.in[0].proto for i, in := range c.in { - vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv, nil) + vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) if err != nil { - t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s", - in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err) + t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s", + in.id, in.first, in.last, in.more, in.proto, in.pkt, err) } if !reflect.DeepEqual(vv, c.out[i].vv) { - t.Errorf("got Process(%+v, %d, %d, %t, %d, %X) = (%X, _, _, _), want = (%X, _, _, _)", - in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), vv.ToView(), c.out[i].vv.ToView()) + t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) = (%X, _, _, _), want = (%X, _, _, _)", + in.id, in.first, in.last, in.more, in.proto, in.pkt, vv.ToView(), c.out[i].vv.ToView()) } if done != c.out[i].done { t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)", @@ -236,11 +243,11 @@ func TestReassemblingTimeout(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { clock := faketime.NewManualClock() - f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock) + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock, nil) for _, event := range test.events { clock.Advance(event.clockAdvance) if frag := event.fragment; frag != nil { - _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data), nil) + _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, pkt(len(frag.data), frag.data)) if err != nil { t.Fatalf("%s: f.Process failed: %s", event.name, err) } @@ -257,17 +264,17 @@ func TestReassemblingTimeout(t *testing.T) { } func TestMemoryLimits(t *testing.T) { - f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"), nil) + f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")) // Send first fragment with id = 1. - f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"), nil) + f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")) // Send first fragment with id = 2. - f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"), nil) + f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")) // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be // evicted. - f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"), nil) + f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")) if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { t.Errorf("Memory limits are not respected: id=0 has not been evicted.") @@ -281,11 +288,11 @@ func TestMemoryLimits(t *testing.T) { } func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil) + f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) // Send the same packet again. - f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil) + f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) got := f.size want := 1 @@ -327,6 +334,7 @@ func TestErrors(t *testing.T) { last: 3, more: true, data: "012", + err: ErrInvalidArgs, }, { name: "exact block size with more and too little data", @@ -376,8 +384,8 @@ func TestErrors(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) - _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data), nil) + f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, nil) + _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, pkt(len(test.data), test.data)) if !errors.Is(err, test.err) { t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) } @@ -498,57 +506,92 @@ func TestPacketFragmenter(t *testing.T) { } } -func TestReleaseCallback(t *testing.T) { +type testTimeoutHandler struct { + pkt *stack.PacketBuffer +} + +func (h *testTimeoutHandler) OnReassemblyTimeout(pkt *stack.PacketBuffer) { + h.pkt = pkt +} + +func TestTimeoutHandler(t *testing.T) { const ( proto = 99 ) - var result int - var callbackReasonIsTimeout bool - cb1 := func(timedOut bool) { result = 1; callbackReasonIsTimeout = timedOut } - cb2 := func(timedOut bool) { result = 2; callbackReasonIsTimeout = timedOut } + pk1 := pkt(1, "1") + pk2 := pkt(1, "2") + + type processParam struct { + first uint16 + last uint16 + more bool + pkt *stack.PacketBuffer + } tests := []struct { - name string - callbacks []func(bool) - timeout bool - wantResult int - wantCallbackReasonIsTimeout bool + name string + params []processParam + wantError bool + wantPkt *stack.PacketBuffer }{ { - name: "callback runs on release", - callbacks: []func(bool){cb1}, - timeout: false, - wantResult: 1, - wantCallbackReasonIsTimeout: false, - }, - { - name: "first callback is nil", - callbacks: []func(bool){nil, cb2}, - timeout: false, - wantResult: 2, - wantCallbackReasonIsTimeout: false, + name: "onTimeout runs", + params: []processParam{ + { + first: 0, + last: 0, + more: true, + pkt: pk1, + }, + }, + wantError: false, + wantPkt: pk1, }, { - name: "two callbacks - first one is set", - callbacks: []func(bool){cb1, cb2}, - timeout: false, - wantResult: 1, - wantCallbackReasonIsTimeout: false, + name: "no first fragment", + params: []processParam{ + { + first: 1, + last: 1, + more: true, + pkt: pk1, + }, + }, + wantError: false, + wantPkt: nil, }, { - name: "callback runs on timeout", - callbacks: []func(bool){cb1}, - timeout: true, - wantResult: 1, - wantCallbackReasonIsTimeout: true, + name: "second pkt is ignored", + params: []processParam{ + { + first: 0, + last: 0, + more: true, + pkt: pk1, + }, + { + first: 0, + last: 0, + more: true, + pkt: pk2, + }, + }, + wantError: false, + wantPkt: pk1, }, { - name: "no callbacks", - callbacks: []func(bool){nil}, - timeout: false, - wantResult: 0, - wantCallbackReasonIsTimeout: false, + name: "invalid args - first is greater than last", + params: []processParam{ + { + first: 1, + last: 0, + more: true, + pkt: pk1, + }, + }, + wantError: true, + wantPkt: nil, }, } @@ -556,29 +599,31 @@ func TestReleaseCallback(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - result = 0 - callbackReasonIsTimeout = false + handler := &testTimeoutHandler{pkt: nil} - f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, handler) - for i, cb := range test.callbacks { - _, _, _, err := f.Process(id, uint16(i), uint16(i), true, proto, vv(1, "0"), cb) - if err != nil { + for _, p := range test.params { + if _, _, _, err := f.Process(id, p.first, p.last, p.more, proto, p.pkt); err != nil && !test.wantError { t.Errorf("f.Process error = %s", err) } } - - r, ok := f.reassemblers[id] - if !ok { - t.Fatalf("Reassemberr not found") - } - f.release(r, test.timeout) - - if result != test.wantResult { - t.Errorf("got result = %d, want = %d", result, test.wantResult) + if !test.wantError { + r, ok := f.reassemblers[id] + if !ok { + t.Fatal("Reassembler not found") + } + f.release(r, true) } - if callbackReasonIsTimeout != test.wantCallbackReasonIsTimeout { - t.Errorf("got callbackReasonIsTimeout = %t, want = %t", callbackReasonIsTimeout, test.wantCallbackReasonIsTimeout) + switch { + case handler.pkt != nil && test.wantPkt == nil: + t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data.ToView()) + case handler.pkt == nil && test.wantPkt != nil: + t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data.ToView()) + case handler.pkt != nil && test.wantPkt != nil: + if diff := cmp.Diff(test.wantPkt.Data.ToView(), handler.pkt.Data.ToView()); diff != "" { + t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff) + } } }) } diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index c0cc0bde0..19f4920b3 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) type hole struct { @@ -41,7 +42,7 @@ type reassembler struct { heap fragHeap done bool creationTime int64 - callback func(bool) + pkt *stack.PacketBuffer } func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { @@ -79,7 +80,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool { return used } -func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (buffer.VectorisedView, uint8, bool, int, error) { +func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) { r.mu.Lock() defer r.mu.Unlock() consumed := 0 @@ -89,18 +90,20 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buf // was waiting on the mutex. We don't have to do anything in this case. return buffer.VectorisedView{}, 0, false, consumed, nil } - // For IPv6, it is possible to have different Protocol values between - // fragments of a packet (because, unlike IPv4, the Protocol is not used to - // identify a fragment). In this case, only the Protocol of the first - // fragment must be used as per RFC 8200 Section 4.5. - // - // TODO(gvisor.dev/issue/3648): The entire first IP header should be recorded - // here (instead of just the protocol) because most IP options should be - // derived from the first fragment. - if first == 0 { - r.proto = proto - } if r.updateHoles(first, last, more) { + // For IPv6, it is possible to have different Protocol values between + // fragments of a packet (because, unlike IPv4, the Protocol is not used to + // identify a fragment). In this case, only the Protocol of the first + // fragment must be used as per RFC 8200 Section 4.5. + // + // TODO(gvisor.dev/issue/3648): During reassembly of an IPv6 packet, IP + // options received in the first fragment should be used - and they should + // override options from following fragments. + if first == 0 { + r.pkt = pkt + r.proto = proto + } + vv := pkt.Data // We store the incoming packet only if it filled some holes. heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)}) consumed = vv.Size() @@ -124,24 +127,3 @@ func (r *reassembler) checkDoneOrMark() bool { r.mu.Unlock() return prev } - -func (r *reassembler) setCallback(c func(bool)) bool { - r.mu.Lock() - defer r.mu.Unlock() - if r.callback != nil { - return false - } - r.callback = c - return true -} - -func (r *reassembler) release(timedOut bool) { - r.mu.Lock() - callback := r.callback - r.callback = nil - r.mu.Unlock() - - if callback != nil { - callback(timedOut) - } -} diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index fa2a70dc8..a0a04a027 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -105,26 +105,3 @@ func TestUpdateHoles(t *testing.T) { } } } - -func TestSetCallback(t *testing.T) { - result := 0 - reasonTimeout := false - - cb1 := func(timedOut bool) { result = 1; reasonTimeout = timedOut } - cb2 := func(timedOut bool) { result = 2; reasonTimeout = timedOut } - - r := newReassembler(FragmentID{}, &faketime.NullClock{}) - if !r.setCallback(cb1) { - t.Errorf("setCallback failed") - } - if r.setCallback(cb2) { - t.Errorf("setCallback should fail if one is already set") - } - r.release(true) - if result != 1 { - t.Errorf("got result = %d, want = 1", result) - } - if !reasonTimeout { - t.Errorf("got reasonTimeout = %t, want = true", reasonTimeout) - } -} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index c7d26e14f..787399e08 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -34,16 +35,16 @@ import ( ) const ( - localIPv4Addr = "\x0a\x00\x00\x01" - remoteIPv4Addr = "\x0a\x00\x00\x02" - ipv4SubnetAddr = "\x0a\x00\x00\x00" - ipv4SubnetMask = "\xff\xff\xff\x00" - ipv4Gateway = "\x0a\x00\x00\x03" - localIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - remoteIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" - ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" + localIPv4Addr = tcpip.Address("\x0a\x00\x00\x01") + remoteIPv4Addr = tcpip.Address("\x0a\x00\x00\x02") + ipv4SubnetAddr = tcpip.Address("\x0a\x00\x00\x00") + ipv4SubnetMask = tcpip.Address("\xff\xff\xff\x00") + ipv4Gateway = tcpip.Address("\x0a\x00\x00\x03") + localIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + remoteIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + ipv6SubnetAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") + ipv6SubnetMask = tcpip.Address("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00") + ipv6Gateway = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") nicID = 1 ) @@ -192,10 +193,6 @@ func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBu panic("not implemented") } -func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { - return tcpip.ErrNotSupported -} - // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (*testObject) ARPHardwareType() header.ARPHardwareType { panic("not implemented") @@ -299,6 +296,10 @@ func (t *testInterface) Enabled() bool { return !t.mu.disabled } +func (*testInterface) Promiscuous() bool { + return false +} + func (t *testInterface) setEnabled(v bool) { t.mu.Lock() defer t.mu.Unlock() @@ -558,59 +559,135 @@ func TestIPv4Send(t *testing.T) { } } -func TestIPv4Receive(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - v4: true, +func TestReceive(t *testing.T) { + tests := []struct { + name string + protoFactory stack.NetworkProtocolFactory + protoNum tcpip.NetworkProtocolNumber + v4 bool + epAddr tcpip.AddressWithPrefix + handlePacket func(*testing.T, stack.NetworkEndpoint, *testInterface) + }{ + { + name: "IPv4", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + v4: true, + epAddr: localIPv4Addr.WithPrefix(), + handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { + const totalLen = header.IPv4MinimumSize + 30 /* payload length */ + + view := buffer.NewView(totalLen) + ip := header.IPv4(view) + ip.Encode(&header.IPv4Fields{ + TotalLength: totalLen, + TTL: ipv4.DefaultTTL, + Protocol: 10, + SrcAddr: remoteIPv4Addr, + DstAddr: localIPv4Addr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + // Make payload be non-zero. + for i := header.IPv4MinimumSize; i < len(view); i++ { + view[i] = uint8(i) + } + + // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv4Addr + nic.testObject.dstAddr = localIPv4Addr + nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if ok := parse.IPv4(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(pkt) + }, }, - } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) - defer ep.Close() + { + name: "IPv6", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + v4: false, + epAddr: localIPv6Addr.WithPrefix(), + handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { + const payloadLen = 30 + view := buffer.NewView(header.IPv6MinimumSize + payloadLen) + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: payloadLen, + NextHeader: 10, + HopLimit: ipv6.DefaultTTL, + SrcAddr: remoteIPv6Addr, + DstAddr: localIPv6Addr, + }) - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } + // Make payload be non-zero. + for i := header.IPv6MinimumSize; i < len(view); i++ { + view[i] = uint8(i) + } - totalLen := header.IPv4MinimumSize + 30 - view := buffer.NewView(totalLen) - ip := header.IPv4(view) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - SrcAddr: remoteIPv4Addr, - DstAddr: localIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) + // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv6Addr + nic.testObject.dstAddr = localIPv6Addr + nic.testObject.contents = view[header.IPv6MinimumSize:][:payloadLen] - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - view[i] = uint8(i) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if _, _, _, _, ok := parse.IPv6(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(pkt) + }, + }, } - // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv4Addr - nic.testObject.dstAddr = localIPv4Addr - nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, + }) + nic := testInterface{ + testObject: testObject{ + t: t, + v4: test.v4, + }, + } + ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, nil, nil, &nic.testObject) + defer ep.Close() - r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: view.ToVectorisedView(), - }) - if _, _, ok := proto.Parse(pkt); !ok { - t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) - } - r.PopulatePacketInfo(pkt) - ep.HandlePacket(pkt) - if nic.testObject.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum) + } + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err) + } else { + ep.DecRef() + } + + stat := s.Stats().IP.PacketsReceived + if got := stat.Value(); got != 0 { + t.Fatalf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 0", got) + } + test.handlePacket(t, ep, &nic) + if nic.testObject.dataCalls != 1 { + t.Errorf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) + } + if got := stat.Value(); got != 1 { + t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got) + } + }) } } @@ -634,10 +711,6 @@ func TestIPv4ReceiveControl(t *testing.T) { {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8}, } - r, err := buildIPv4Route(localIPv4Addr, "\x0a\x00\x00\xbb") - if err != nil { - t.Fatal(err) - } for _, c := range cases { t.Run(c.name, func(t *testing.T) { s := buildDummyStack(t) @@ -705,8 +778,18 @@ func TestIPv4ReceiveControl(t *testing.T) { nic.testObject.typ = c.expectedTyp nic.testObject.extra = c.expectedExtra + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") + } + addr := localIPv4Addr.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() + } + pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) if want := c.expectedCount; nic.testObject.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) @@ -716,7 +799,9 @@ func TestIPv4ReceiveControl(t *testing.T) { } func TestIPv4FragmentationReceive(t *testing.T) { - s := buildDummyStack(t) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) nic := testInterface{ testObject: testObject{ @@ -774,11 +859,6 @@ func TestIPv4FragmentationReceive(t *testing.T) { nic.testObject.dstAddr = localIPv4Addr nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) - r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - // Send first segment. pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: frag1.ToVectorisedView(), @@ -786,7 +866,18 @@ func TestIPv4FragmentationReceive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - r.PopulatePacketInfo(pkt) + + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") + } + addr := localIPv4Addr.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() + } + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls) @@ -799,7 +890,6 @@ func TestIPv4FragmentationReceive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) @@ -852,61 +942,6 @@ func TestIPv6Send(t *testing.T) { } } -func TestIPv6Receive(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - }, - } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - totalLen := header.IPv6MinimumSize + 30 - view := buffer.NewView(totalLen) - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(totalLen - header.IPv6MinimumSize), - NextHeader: 10, - HopLimit: 20, - SrcAddr: remoteIPv6Addr, - DstAddr: localIPv6Addr, - }) - - // Make payload be non-zero. - for i := header.IPv6MinimumSize; i < totalLen; i++ { - view[i] = uint8(i) - } - - // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv6Addr - nic.testObject.dstAddr = localIPv6Addr - nic.testObject.contents = view[header.IPv6MinimumSize:totalLen] - - r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: view.ToVectorisedView(), - }) - if _, _, ok := proto.Parse(pkt); !ok { - t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) - } - r.PopulatePacketInfo(pkt) - ep.HandlePacket(pkt) - if nic.testObject.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) - } -} - func TestIPv6ReceiveControl(t *testing.T) { newUint16 := func(v uint16) *uint16 { return &v } @@ -933,13 +968,6 @@ func TestIPv6ReceiveControl(t *testing.T) { {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, } - r, err := buildIPv6Route( - localIPv6Addr, - "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", - ) - if err != nil { - t.Fatal(err) - } for _, c := range cases { t.Run(c.name, func(t *testing.T) { s := buildDummyStack(t) @@ -1018,8 +1046,17 @@ func TestIPv6ReceiveControl(t *testing.T) { // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint") + } + addr := localIPv6Addr.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() + } pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) if want := c.expectedCount; nic.testObject.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) @@ -1202,7 +1239,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ipHdrLen := header.IPv4MinimumSize + ipv4Options.AllocationSize() + ipHdrLen := header.IPv4MinimumSize + ipv4Options.SizeWithPadding() totalLen := ipHdrLen + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1247,7 +1284,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.AllocationSize())) + ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.SizeWithPadding())) ip.Encode(&header.IPv4Fields{ Protocol: transportProto, TTL: ipv4.DefaultTTL, diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 9b5e37fee..488945226 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -90,7 +90,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { iph := header.IPv4(pkt.NetworkHeader().View()) var newOptions header.IPv4Options - if len(iph) > header.IPv4MinimumSize { + if opts := iph.Options(); len(opts) != 0 { // RFC 1122 section 3.2.2.6 (page 43) (and similar for other round trip // type ICMP packets): // If a Record Route and/or Time Stamp option is received in an @@ -106,7 +106,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { } else { op = &optionUsageReceive{} } - aux, tmp, err := e.processIPOptions(pkt, iph.Options(), op) + aux, tmp, err := e.processIPOptions(pkt, opts, op) if err != nil { switch { case @@ -290,6 +290,13 @@ type icmpReasonProtoUnreachable struct{} func (*icmpReasonProtoUnreachable) isICMPReason() {} +// icmpReasonTTLExceeded is an error where a packet's time to live exceeded in +// transit to its final destination, as per RFC 792 page 6, Time Exceeded +// Message. +type icmpReasonTTLExceeded struct{} + +func (*icmpReasonTTLExceeded) isICMPReason() {} + // icmpReasonReassemblyTimeout is an error where insufficient fragments are // received to complete reassembly of a packet within a configured time after // the reception of the first-arriving fragment of that packet. @@ -342,11 +349,31 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi return nil } + // If we hit a TTL Exceeded error, then we know we are operating as a router. + // As per RFC 792 page 6, Time Exceeded Message, + // + // If the gateway processing a datagram finds the time to live field + // is zero it must discard the datagram. The gateway may also notify + // the source host via the time exceeded message. + // + // ... + // + // Code 0 may be received from a gateway. ... + // + // Note, Code 0 is the TTL exceeded error. + // + // If we are operating as a router/gateway, don't use the packet's destination + // address as the response's source address as we should not not own the + // destination address of a packet we are forwarding. + localAddr := origIPHdrDst + if _, ok := reason.(*icmpReasonTTLExceeded); ok { + localAddr = "" + } // Even if we were able to receive a packet from some remote, we may not have // a route to it - the remote may be blocked via routing rules. We must always // consult our routing table and find a route to the remote before sending any // packet. - route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) + route, err := p.stack.FindRoute(pkt.NICID, localAddr, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) if err != nil { return err } @@ -454,6 +481,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) counter = sent.DstUnreachable + case *icmpReasonTTLExceeded: + icmpHdr.SetType(header.ICMPv4TimeExceeded) + icmpHdr.SetCode(header.ICMPv4TTLExceeded) + counter = sent.TimeExceeded case *icmpReasonReassemblyTimeout: icmpHdr.SetType(header.ICMPv4TimeExceeded) icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout) @@ -483,3 +514,18 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi counter.Increment() return nil } + +// OnReassemblyTimeout implements fragmentation.TimeoutHandler. +func (p *protocol) OnReassemblyTimeout(pkt *stack.PacketBuffer) { + // OnReassemblyTimeout sends a Time Exceeded Message, as per RFC 792: + // + // If a host reassembling a fragmented datagram cannot complete the + // reassembly due to missing fragments within its time limit it discards the + // datagram, and it may send a time exceeded message. + // + // If fragment zero is not available then no time exceeded need be sent at + // all. + if pkt != nil { + p.returnError(&icmpReasonReassemblyTimeout{}, pkt) + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index a376cb8ec..1efe6297a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -206,12 +206,12 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s if opts, ok = params.Options.(header.IPv4Options); !ok { panic(fmt.Sprintf("want IPv4Options, got %T", params.Options)) } - hdrLen += opts.AllocationSize() + hdrLen += opts.SizeWithPadding() if hdrLen > header.IPv4MaximumHeaderSize { // Since we have no way to report an error we must either panic or create // a packet which is different to what was requested. Choose panic as this // would be a programming error that should be caught in testing. - panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", params.Options.AllocationSize(), header.IPv4MaximumOptionsSize)) + panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", params.Options.SizeWithPadding(), header.IPv4MaximumOptionsSize)) } } ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen)) @@ -260,16 +260,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) - return e.writePacket(r, gso, pkt) -} -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesOutputDropped.Increment() + e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment() return nil } @@ -286,24 +283,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet if err == nil { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - route.PopulatePacketInfo(pkt) // Since we rewrote the packet but it is being routed back to us, we can // safely assume the checksum is valid. pkt.RXTransportChecksumValidated = true - ep.HandlePacket(pkt) + ep.(*endpoint).handlePacket(pkt) } return nil } } + return e.writePacket(r, gso, pkt, false /* headerIncluded */) +} + +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) *tcpip.Error { if r.Loop&stack.PacketLoop != 0 { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - loopedR := r.MakeLoopedRoute() - loopedR.PopulatePacketInfo(pkt) - loopedR.Release() - e.HandlePacket(pkt) + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = !headerIncluded + e.handlePacket(pkt) } } if r.Loop&stack.PacketOut == 0 { @@ -374,8 +374,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - ipt := e.protocol.stack.IPTables() - dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -400,9 +399,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - route.PopulatePacketInfo(pkt) - ep.HandlePacket(pkt) + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.(*endpoint).handlePacket(pkt) } n++ continue @@ -479,16 +479,85 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return tcpip.ErrMalformedHeader } - return e.writePacket(r, nil /* gso */, pkt) + return e.writePacket(r, nil /* gso */, pkt, true /* headerIncluded */) +} + +// forwardPacket attempts to forward a packet to its final destination. +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { + h := header.IPv4(pkt.NetworkHeader().View()) + ttl := h.TTL() + if ttl == 0 { + // As per RFC 792 page 6, Time Exceeded Message, + // + // If the gateway processing a datagram finds the time to live field + // is zero it must discard the datagram. The gateway may also notify + // the source host via the time exceeded message. + return e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt) + } + + dstAddr := h.DestinationAddress() + + // Check if the destination is owned by the stack. + networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) + if err == nil { + networkEndpoint.(*endpoint).handlePacket(pkt) + return nil + } + if err != tcpip.ErrBadAddress { + return err + } + + r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer r.Release() + + // We need to do a deep copy of the IP packet because + // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do + // not own it. + newHdr := header.IPv4(stack.PayloadSince(pkt.NetworkHeader())) + + // As per RFC 791 page 30, Time to Live, + // + // This field must be decreased at each point that the internet header + // is processed to reflect the time spent processing the datagram. + // Even if no local information is available on the time actually + // spent, the field must be decremented by 1. + newHdr.SetTTL(ttl - 1) + + return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(newHdr).ToVectorisedView(), + })) } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { + stats := e.protocol.stack.Stats() + stats.IP.PacketsReceived.Increment() + if !e.isEnabled() { + stats.IP.DisabledPacketsReceived.Increment() return } + // Loopback traffic skips the prerouting chain. + if !e.nic.IsLoopback() { + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + // iptables is telling us to drop the packet. + stats.IP.IPTablesPreroutingDropped.Increment() + return + } + } + + e.handlePacket(pkt) +} + +// handlePacket is like HandlePacket except it does not perform the prerouting +// iptables hook. +func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.protocol.stack.Stats() @@ -497,6 +566,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { stats.IP.MalformedPacketsReceived.Increment() return } + srcAddr := h.SourceAddress() + dstAddr := h.DestinationAddress() + + addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint == nil { + if !e.protocol.Forwarding() { + stats.IP.InvalidDestinationAddressesReceived.Increment() + return + } + + _ = e.forwardPacket(pkt) + return + } + subnet := addressEndpoint.AddressWithPrefix().Subnet() + addressEndpoint.DecRef() // There has been some confusion regarding verifying checksums. We need // just look for negative 0 (0xffff) as the checksum, as it's not possible to @@ -528,15 +612,16 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // When a host sends any datagram, the IP source address MUST // be one of its own IP addresses (but not a broadcast or // multicast address). - if pkt.NetworkPacketInfo.RemoteAddressBroadcast || header.IsV4MulticastAddress(h.SourceAddress()) { + if directedBroadcast := subnet.IsBroadcast(srcAddr); directedBroadcast || srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) { stats.IP.InvalidSourceAddressesReceived.Increment() return } + pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast + // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. stats.IP.IPTablesInputDropped.Increment() return @@ -565,29 +650,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - // Set up a callback in case we need to send a Time Exceeded Message, as per - // RFC 792: - // - // If a host reassembling a fragmented datagram cannot complete the - // reassembly due to missing fragments within its time limit it discards - // the datagram, and it may send a time exceeded message. - // - // If fragment zero is not available then no time exceeded need be sent at - // all. - var releaseCB func(bool) - if start == 0 { - pkt := pkt.Clone() - releaseCB = func(timedOut bool) { - if timedOut { - _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt) - } - } - } - - var ready bool - var err error proto := h.Protocol() - pkt.Data, _, ready, err = e.protocol.fragmentation.Process( + data, _, ready, err := e.protocol.fragmentation.Process( // As per RFC 791 section 2.3, the identification value is unique // for a source-destination pair and protocol. fragmentation.FragmentID{ @@ -600,8 +664,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { start+uint16(pkt.Data.Size())-1, h.More(), proto, - pkt.Data, - releaseCB, + pkt, ) if err != nil { stats.IP.MalformedPacketsReceived.Increment() @@ -611,6 +674,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !ready { return } + pkt.Data = data // The reassembler doesn't take care of fixing up the header, so we need // to do it here. @@ -628,11 +692,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.handleICMP(pkt) return } - if len(h.Options()) != 0 { + if opts := h.Options(); len(opts) != 0 { // TODO(gvisor.dev/issue/4586): // When we add forwarding support we should use the verified options // rather than just throwing them away. - aux, _, err := e.processIPOptions(pkt, h.Options(), &optionUsageReceive{}) + aux, _, err := e.processIPOptions(pkt, opts, &optionUsageReceive{}) if err != nil { switch { case @@ -778,6 +842,7 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool { var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) +var _ fragmentation.TimeoutHandler = (*protocol)(nil) type protocol struct { stack *stack.Stack @@ -942,13 +1007,14 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { } hashIV := r[buckets] - return &protocol{ - stack: s, - ids: ids, - hashIV: hashIV, - defaultTTL: DefaultTTL, - fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()), + p := &protocol{ + stack: s, + ids: ids, + hashIV: hashIV, + defaultTTL: DefaultTTL, } + p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) + return p } func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) { diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index c6e565455..4e4e1f3b4 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -103,6 +103,262 @@ func TestExcludeBroadcast(t *testing.T) { }) } +// TestIPv4Encode checks that ipv4.Encode correctly fills out the requested +// fields when options are supplied. +func TestIPv4EncodeOptions(t *testing.T) { + tests := []struct { + name string + options header.IPv4Options + encodedOptions header.IPv4Options // reply should look like this + wantIHL int + }{ + { + name: "valid no options", + wantIHL: header.IPv4MinimumSize, + }, + { + name: "one byte options", + options: header.IPv4Options{1}, + encodedOptions: header.IPv4Options{1, 0, 0, 0}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "two byte options", + options: header.IPv4Options{1, 1}, + encodedOptions: header.IPv4Options{1, 1, 0, 0}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "three byte options", + options: header.IPv4Options{1, 1, 1}, + encodedOptions: header.IPv4Options{1, 1, 1, 0}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "four byte options", + options: header.IPv4Options{1, 1, 1, 1}, + encodedOptions: header.IPv4Options{1, 1, 1, 1}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "five byte options", + options: header.IPv4Options{1, 1, 1, 1, 1}, + encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0}, + wantIHL: header.IPv4MinimumSize + 8, + }, + { + name: "thirty nine byte options", + options: header.IPv4Options{ + 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, + 33, 34, 35, 36, 37, 38, 39, + }, + encodedOptions: header.IPv4Options{ + 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, + 33, 34, 35, 36, 37, 38, 39, 0, + }, + wantIHL: header.IPv4MinimumSize + 40, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + paddedOptionLength := test.options.SizeWithPadding() + ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength + if ipHeaderLength > header.IPv4MaximumHeaderSize { + t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) + } + totalLen := uint16(ipHeaderLength) + hdr := buffer.NewPrependable(int(totalLen)) + ip := header.IPv4(hdr.Prepend(ipHeaderLength)) + // To check the padding works, poison the last byte of the options space. + if paddedOptionLength != len(test.options) { + ip.SetHeaderLength(uint8(ipHeaderLength)) + ip.Options()[paddedOptionLength-1] = 0xff + ip.SetHeaderLength(0) + } + ip.Encode(&header.IPv4Fields{ + Options: test.options, + }) + options := ip.Options() + wantOptions := test.encodedOptions + if got, want := int(ip.HeaderLength()), test.wantIHL; got != want { + t.Errorf("got IHL of %d, want %d", got, want) + } + + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(wantOptions) == 0 && len(options) == 0 { + return + } + + if diff := cmp.Diff(wantOptions, options); diff != "" { + t.Errorf("options mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + randomSequence = 123 + randomIdent = 42 + ) + + ipv4Addr1 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), + PrefixLen: 8, + } + ipv4Addr2 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), + PrefixLen: 8, + } + remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4()) + remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4()) + + tests := []struct { + name string + TTL uint8 + expectErrorICMP bool + }{ + { + name: "TTL of zero", + TTL: 0, + expectErrorICMP: true, + }, + { + name: "TTL of one", + TTL: 1, + expectErrorICMP: false, + }, + { + name: "TTL of two", + TTL: 2, + expectErrorICMP: false, + }, + { + name: "Max TTL", + TTL: math.MaxUint8, + expectErrorICMP: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + }) + // We expect at most a single packet in response to our ICMP Echo Request. + e1 := channel.New(1, ipv4.MaxTotalSize, "") + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1} + if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err) + } + + e2 := channel.New(1, ipv4.MaxTotalSize, "") + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2} + if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: ipv4Addr1.Subnet(), + NIC: nicID1, + }, + { + Destination: ipv4Addr2.Subnet(), + NIC: nicID2, + }, + }) + + if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err) + } + + totalLen := uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize) + hdr := buffer.NewPrependable(int(totalLen)) + icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmp.SetIdent(randomIdent) + icmp.SetSequence(randomSequence) + icmp.SetType(header.ICMPv4Echo) + icmp.SetCode(header.ICMPv4UnusedCode) + icmp.SetChecksum(0) + icmp.SetChecksum(^header.Checksum(icmp, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: totalLen, + Protocol: uint8(header.ICMPv4ProtocolNumber), + TTL: test.TTL, + SrcAddr: remoteIPv4Addr1, + DstAddr: remoteIPv4Addr2, + }) + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) + e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + + if test.expectErrorICMP { + reply, ok := e1.Read() + if !ok { + t.Fatal("expected ICMP TTL Exceeded packet through incoming NIC") + } + + checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(ipv4Addr1.Address), + checker.DstAddr(remoteIPv4Addr1), + checker.TTL(ipv4.DefaultTTL), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4TimeExceeded), + checker.ICMPv4Code(header.ICMPv4TTLExceeded), + checker.ICMPv4Payload([]byte(hdr.View())), + ), + ) + + if n := e2.Drain(); n != 0 { + t.Fatalf("got e2.Drain() = %d, want = 0", n) + } + } else { + reply, ok := e2.Read() + if !ok { + t.Fatal("expected ICMP Echo packet through outgoing NIC") + } + + checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(remoteIPv4Addr1), + checker.DstAddr(remoteIPv4Addr2), + checker.TTL(test.TTL-1), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4Echo), + checker.ICMPv4Code(header.ICMPv4UnusedCode), + checker.ICMPv4Payload(nil), + ), + ) + + if n := e1.Drain(); n != 0 { + t.Fatalf("got e1.Drain() = %d, want = 0", n) + } + } + }) + } +} + // TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and // checks the response. func TestIPv4Sanity(t *testing.T) { @@ -197,6 +453,14 @@ func TestIPv4Sanity(t *testing.T) { replyOptions: header.IPv4Options{1, 1, 0, 0}, }, { + name: "Check option padding", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{1, 1, 1}, + replyOptions: header.IPv4Options{1, 1, 1, 0}, + }, + { name: "bad header length", headerLength: header.IPv4MinimumSize - 1, maxTotalLength: ipv4.MaxTotalSize, @@ -599,9 +863,10 @@ func TestIPv4Sanity(t *testing.T) { }, }) - ipHeaderLength := header.IPv4MinimumSize + test.options.AllocationSize() + paddedOptionLength := test.options.SizeWithPadding() + ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength if ipHeaderLength > header.IPv4MaximumHeaderSize { - t.Fatalf("too many bytes in options: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) + t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) } totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(totalLen)) @@ -618,6 +883,12 @@ func TestIPv4Sanity(t *testing.T) { if test.maxTotalLength < totalLen { totalLen = test.maxTotalLength } + // To check the padding works, poison the options space. + if paddedOptionLength != len(test.options) { + ip.SetHeaderLength(uint8(ipHeaderLength)) + ip.Options()[paddedOptionLength-1] = 0x01 + } + ip.Encode(&header.IPv4Fields{ TotalLength: totalLen, Protocol: test.transportProtocol, @@ -732,7 +1003,7 @@ func TestIPv4Sanity(t *testing.T) { } // If the IP options change size then the packet will change size, so // some IP header fields will need to be adjusted for the checks. - sizeChange := len(test.replyOptions) - len(test.options) + sizeChange := len(test.replyOptions) - paddedOptionLength checker.IPv4(t, replyIPHeader, checker.IPv4HeaderLength(ipHeaderLength+sizeChange), @@ -2441,9 +2712,6 @@ func TestPacketQueing(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, arp.ProtocolNumber, arp.ProtocolAddress, err) - } if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil { t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err) } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 8502b848c..beb8f562e 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -750,6 +750,12 @@ type icmpReasonPortUnreachable struct{} func (*icmpReasonPortUnreachable) isICMPReason() {} +// icmpReasonHopLimitExceeded is an error where a packet's hop limit exceeded in +// transit to its final destination, as per RFC 4443 section 3.3. +type icmpReasonHopLimitExceeded struct{} + +func (*icmpReasonHopLimitExceeded) isICMPReason() {} + // icmpReasonReassemblyTimeout is an error where insufficient fragments are // received to complete reassembly of a packet within a configured time after // the reception of the first-arriving fragment of that packet. @@ -794,11 +800,27 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi return nil } + // If we hit a Hop Limit Exceeded error, then we know we are operating as a + // router. As per RFC 4443 section 3.3: + // + // If a router receives a packet with a Hop Limit of zero, or if a + // router decrements a packet's Hop Limit to zero, it MUST discard the + // packet and originate an ICMPv6 Time Exceeded message with Code 0 to + // the source of the packet. This indicates either a routing loop or + // too small an initial Hop Limit value. + // + // If we are operating as a router, do not use the packet's destination + // address as the response's source address as we should not own the + // destination address of a packet we are forwarding. + localAddr := origIPHdrDst + if _, ok := reason.(*icmpReasonHopLimitExceeded); ok { + localAddr = "" + } // Even if we were able to receive a packet from some remote, we may not have // a route to it - the remote may be blocked via routing rules. We must always // consult our routing table and find a route to the remote before sending any // packet. - route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) + route, err := p.stack.FindRoute(pkt.NICID, localAddr, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) if err != nil { return err } @@ -811,8 +833,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi return nil } - network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() - if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber { // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored. // Unfortunately at this time ICMP Packets do not have a transport @@ -830,6 +850,8 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi } } + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + // As per RFC 4443 section 2.4 // // (c) Every ICMPv6 error message (type < 128) MUST include @@ -873,6 +895,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6PortUnreachable) counter = sent.DstUnreachable + case *icmpReasonHopLimitExceeded: + icmpHdr.SetType(header.ICMPv6TimeExceeded) + icmpHdr.SetCode(header.ICMPv6HopLimitExceeded) + counter = sent.TimeExceeded case *icmpReasonReassemblyTimeout: icmpHdr.SetType(header.ICMPv6TimeExceeded) icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout) @@ -896,3 +922,16 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi counter.Increment() return nil } + +// OnReassemblyTimeout implements fragmentation.TimeoutHandler. +func (p *protocol) OnReassemblyTimeout(pkt *stack.PacketBuffer) { + // OnReassemblyTimeout sends a Time Exceeded Message as per RFC 2460 Section + // 4.5: + // + // If the first fragment (i.e., the one with a Fragment Offset of zero) has + // been received, an ICMP Time Exceeded -- Fragment Reassembly Time Exceeded + // message should be sent to the source of that fragment. + if pkt != nil { + p.returnError(&icmpReasonReassemblyTimeout{}, pkt) + } +} diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 76013daa1..9bc02d851 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -144,6 +144,10 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) Promiscuous() bool { + return false +} + func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { r := stack.Route{ NetProto: protocol, @@ -174,13 +178,8 @@ func TestICMPCounts(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: test.useNeighborCache, }) - { - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -206,11 +205,16 @@ func TestICMPCounts(t *testing.T) { t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") + } + addr := lladdr0.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() } - defer r.Release() var tllData [header.NDPLinkLayerAddressSize]byte header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ @@ -279,10 +283,9 @@ func TestICMPCounts(t *testing.T) { PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) } @@ -290,7 +293,7 @@ func TestICMPCounts(t *testing.T) { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) handleIPv6Payload(icmp) } @@ -317,13 +320,8 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: true, }) - { - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -349,11 +347,16 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") + } + addr := lladdr0.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() } - defer r.Release() var tllData [header.NDPLinkLayerAddressSize]byte header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ @@ -422,10 +425,9 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) } @@ -433,7 +435,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) handleIPv6Payload(icmp) } @@ -1775,17 +1777,19 @@ func TestCallsToNeighborCache(t *testing.T) { t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(nicID, lladdr0, test.source, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") + } + addr := lladdr0.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() } - defer r.Release() - - // TODO(gvisor.dev/issue/4517): Remove the need for this manual patch. - r.LocalAddress = test.destination icmp := test.createPacket() - icmp.SetChecksum(header.ICMPv6Checksum(icmp, r.RemoteAddress, r.LocalAddress, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{})) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: header.IPv6MinimumSize, Data: buffer.View(icmp).ToVectorisedView(), @@ -1795,10 +1799,9 @@ func TestCallsToNeighborCache(t *testing.T) { PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, - SrcAddr: r.RemoteAddress, - DstAddr: r.LocalAddress, + SrcAddr: test.source, + DstAddr: test.destination, }) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) // Confirm the endpoint calls the correct NUDHandler method. diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 0526190cc..7a00f6314 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -441,17 +441,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) - return e.writePacket(r, gso, pkt, params.Protocol) -} -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesOutputDropped.Increment() + e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment() return nil } @@ -467,24 +463,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - route.PopulatePacketInfo(pkt) // Since we rewrote the packet but it is being routed back to us, we can // safely assume the checksum is valid. pkt.RXTransportChecksumValidated = true - ep.HandlePacket(pkt) + ep.(*endpoint).handlePacket(pkt) } return nil } } + return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */) +} + +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) *tcpip.Error { if r.Loop&stack.PacketLoop != 0 { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - loopedR := r.MakeLoopedRoute() - loopedR.PopulatePacketInfo(pkt) - loopedR.Release() - e.HandlePacket(pkt) + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = !headerIncluded + e.handlePacket(pkt) } } if r.Loop&stack.PacketOut == 0 { @@ -558,8 +557,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - ipt := e.protocol.stack.IPTables() - dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -584,9 +582,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - route.PopulatePacketInfo(pkt) - ep.HandlePacket(pkt) + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.(*endpoint).handlePacket(pkt) } n++ continue @@ -640,16 +639,85 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return tcpip.ErrMalformedHeader } - return e.writePacket(r, nil /* gso */, pkt, proto) + return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */) +} + +// forwardPacket attempts to forward a packet to its final destination. +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { + h := header.IPv6(pkt.NetworkHeader().View()) + hopLimit := h.HopLimit() + if hopLimit <= 1 { + // As per RFC 4443 section 3.3, + // + // If a router receives a packet with a Hop Limit of zero, or if a + // router decrements a packet's Hop Limit to zero, it MUST discard the + // packet and originate an ICMPv6 Time Exceeded message with Code 0 to + // the source of the packet. This indicates either a routing loop or + // too small an initial Hop Limit value. + return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt) + } + + dstAddr := h.DestinationAddress() + + // Check if the destination is owned by the stack. + networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) + if err == nil { + networkEndpoint.(*endpoint).handlePacket(pkt) + return nil + } + if err != tcpip.ErrBadAddress { + return err + } + + r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer r.Release() + + // We need to do a deep copy of the IP packet because + // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do + // not own it. + newHdr := header.IPv6(stack.PayloadSince(pkt.NetworkHeader())) + + // As per RFC 8200 section 3, + // + // Hop Limit 8-bit unsigned integer. Decremented by 1 by + // each node that forwards the packet. + newHdr.SetHopLimit(hopLimit - 1) + + return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(newHdr).ToVectorisedView(), + })) } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { + stats := e.protocol.stack.Stats() + stats.IP.PacketsReceived.Increment() + if !e.isEnabled() { + stats.IP.DisabledPacketsReceived.Increment() return } + // Loopback traffic skips the prerouting chain. + if !e.nic.IsLoopback() { + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + // iptables is telling us to drop the packet. + stats.IP.IPTablesPreroutingDropped.Increment() + return + } + } + + e.handlePacket(pkt) +} + +// handlePacket is like HandlePacket except it does not perform the prerouting +// iptables hook. +func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.protocol.stack.Stats() @@ -669,6 +737,18 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } + addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint == nil { + if !e.protocol.Forwarding() { + stats.IP.InvalidDestinationAddressesReceived.Increment() + return + } + + _ = e.forwardPacket(pkt) + return + } + addressEndpoint.DecRef() + // vv consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). // - The transport header, if present. @@ -681,8 +761,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. stats.IP.IPTablesInputDropped.Increment() return @@ -888,18 +967,6 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - // Set up a callback in case we need to send a Time Exceeded Message as - // per RFC 2460 Section 4.5. - var releaseCB func(bool) - if start == 0 { - pkt := pkt.Clone() - releaseCB = func(timedOut bool) { - if timedOut { - _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt) - } - } - } - // Note that pkt doesn't have its transport header set after reassembly, // and won't until DeliverNetworkPacket sets it. data, proto, ready, err := e.protocol.fragmentation.Process( @@ -914,17 +981,17 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { start+uint16(fragmentPayloadLen)-1, extHdr.More(), uint8(rawPayload.Identifier), - rawPayload.Buf, - releaseCB, + pkt, ) if err != nil { stats.IP.MalformedPacketsReceived.Increment() stats.IP.MalformedFragmentsReceived.Increment() return } - pkt.Data = data if ready { + pkt.Data = data + // We create a new iterator with the reassembled packet because we could // have more extension headers in the reassembled payload, as per RFC // 8200 section 4.5. We also use the NextHeader value from the first @@ -1335,6 +1402,7 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool { var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) +var _ fragmentation.TimeoutHandler = (*protocol)(nil) type protocol struct { stack *stack.Stack @@ -1590,10 +1658,9 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { return func(s *stack.Stack) stack.NetworkProtocol { p := &protocol{ - stack: s, - fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()), - ids: ids, - hashIV: hashIV, + stack: s, + ids: ids, + hashIV: hashIV, ndpDisp: opts.NDPDisp, ndpConfigs: opts.NDPConfigs, @@ -1601,6 +1668,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { tempIIDSeed: opts.TempIIDSeed, autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, } + p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) p.mu.eps = make(map[*endpoint]struct{}) p.SetDefaultTTL(DefaultTTL) return p diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 1bfcdde25..a671d4bac 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -18,6 +18,7 @@ import ( "encoding/hex" "fmt" "math" + "net" "testing" "github.com/google/go-cmp/cmp" @@ -2821,3 +2822,160 @@ func TestFragmentationErrors(t *testing.T) { }) } } + +func TestForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + randomSequence = 123 + randomIdent = 42 + ) + + ipv6Addr1 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::1").To16()), + PrefixLen: 64, + } + ipv6Addr2 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11::1").To16()), + PrefixLen: 64, + } + remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16()) + remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16()) + + tests := []struct { + name string + TTL uint8 + expectErrorICMP bool + }{ + { + name: "TTL of zero", + TTL: 0, + expectErrorICMP: true, + }, + { + name: "TTL of one", + TTL: 1, + expectErrorICMP: true, + }, + { + name: "TTL of two", + TTL: 2, + expectErrorICMP: false, + }, + { + name: "TTL of three", + TTL: 3, + expectErrorICMP: false, + }, + { + name: "Max TTL", + TTL: math.MaxUint8, + expectErrorICMP: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + }) + // We expect at most a single packet in response to our ICMP Echo Request. + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + ipv6ProtoAddr1 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr1} + if err := s.AddProtocolAddress(nicID1, ipv6ProtoAddr1); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv6ProtoAddr1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + ipv6ProtoAddr2 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr2} + if err := s.AddProtocolAddress(nicID2, ipv6ProtoAddr2); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv6ProtoAddr2, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: ipv6Addr1.Subnet(), + NIC: nicID1, + }, + { + Destination: ipv6Addr2.Subnet(), + NIC: nicID2, + }, + }) + + if err := s.SetForwarding(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err) + } + + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize) + icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + icmp.SetIdent(randomIdent) + icmp.SetSequence(randomSequence) + icmp.SetType(header.ICMPv6EchoRequest) + icmp.SetCode(header.ICMPv6UnusedCode) + icmp.SetChecksum(0) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: test.TTL, + SrcAddr: remoteIPv6Addr1, + DstAddr: remoteIPv6Addr2, + }) + requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) + e1.InjectInbound(ProtocolNumber, requestPkt) + + if test.expectErrorICMP { + reply, ok := e1.Read() + if !ok { + t.Fatal("expected ICMP Hop Limit Exceeded packet through incoming NIC") + } + + checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(ipv6Addr1.Address), + checker.DstAddr(remoteIPv6Addr1), + checker.TTL(DefaultTTL), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6TimeExceeded), + checker.ICMPv6Code(header.ICMPv6HopLimitExceeded), + checker.ICMPv6Payload([]byte(hdr.View())), + ), + ) + + if n := e2.Drain(); n != 0 { + t.Fatalf("got e2.Drain() = %d, want = 0", n) + } + } else { + reply, ok := e2.Read() + if !ok { + t.Fatal("expected ICMP Echo Request packet through outgoing NIC") + } + + checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(remoteIPv6Addr1), + checker.DstAddr(remoteIPv6Addr2), + checker.TTL(test.TTL-1), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoRequest), + checker.ICMPv6Code(header.ICMPv6UnusedCode), + checker.ICMPv6Payload(nil), + ), + ) + + if n := e1.Drain(); n != 0 { + t.Fatalf("got e1.Drain() = %d, want = 0", n) + } + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 981d1371a..37e8b1083 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -45,10 +45,6 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } - if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) - } - { subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) if err != nil { @@ -73,6 +69,17 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig } t.Cleanup(ep.Close) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") + } + addr := llladdr.WithPrefix() + if addressEP, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + addressEP.DecRef() + } + return s, ep } @@ -961,22 +968,17 @@ func TestNDPValidation(t *testing.T) { for _, stackTyp := range stacks { t.Run(stackTyp.name, func(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) { + setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { t.Helper() // Create a stack with the assigned link-local address lladdr0 // and an endpoint to lladdr1. s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache) - r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) - } - - return s, ep, r + return s, ep } - handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { + handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { nextHdr := uint8(header.ICMPv6ProtocolNumber) var extensions buffer.View if atomicFragment { @@ -994,13 +996,12 @@ func TestNDPValidation(t *testing.T) { PayloadLength: uint16(len(payload) + len(extensions)), NextHeader: nextHdr, HopLimit: hopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) } - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) } @@ -1114,8 +1115,7 @@ func TestNDPValidation(t *testing.T) { t.Run(name, func(t *testing.T) { for _, test := range subTests { t.Run(test.name, func(t *testing.T) { - s, ep, r := setup(t) - defer r.Release() + s, ep := setup(t) if isRouter { // Enabling forwarding makes the stack act as a router. @@ -1131,7 +1131,7 @@ func TestNDPValidation(t *testing.T) { copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) // Rx count of the NDP message should initially be 0. if got := typStat.Value(); got != 0 { @@ -1152,7 +1152,7 @@ func TestNDPValidation(t *testing.T) { t.FailNow() } - handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r) + handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep) // Rx count of the NDP packet should have increased. if got := typStat.Value(); got != 1 { diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go index 7cc52985e..5c3363759 100644 --- a/pkg/tcpip/network/testutil/testutil.go +++ b/pkg/tcpip/network/testutil/testutil.go @@ -85,21 +85,6 @@ func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts st return n, nil } -// WriteRawPacket implements LinkEndpoint.WriteRawPacket. -func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - if ep.allowPackets == 0 { - return ep.err - } - ep.allowPackets-- - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - ep.WrittenPackets = append(ep.WrittenPackets, pkt) - - return nil -} - // Attach implements LinkEndpoint.Attach. func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {} diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 8e0ee1cd7..1c2afd554 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -148,10 +148,6 @@ func main() { log.Fatal(err) } - if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - log.Fatal(err) - } - subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr)))) if err != nil { log.Fatal(err) diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go new file mode 100644 index 000000000..e1b0d6354 --- /dev/null +++ b/pkg/tcpip/socketops.go @@ -0,0 +1,63 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "sync/atomic" +) + +// SocketOptions contains all the variables which store values for SOL_SOCKET +// level options. +// +// +stateify savable +type SocketOptions struct { + // These fields are accessed and modified using atomic operations. + + // broadcastEnabled determines whether datagram sockets are allowed to send + // packets to a broadcast address. + broadcastEnabled uint32 + + // passCredEnabled determines whether SCM_CREDENTIALS socket control messages + // are enabled. + passCredEnabled uint32 +} + +func storeAtomicBool(addr *uint32, v bool) { + var val uint32 + if v { + val = 1 + } + atomic.StoreUint32(addr, val) +} + +// GetBroadcast gets value for SO_BROADCAST option. +func (so *SocketOptions) GetBroadcast() bool { + return atomic.LoadUint32(&so.broadcastEnabled) != 0 +} + +// SetBroadcast sets value for SO_BROADCAST option. +func (so *SocketOptions) SetBroadcast(v bool) { + storeAtomicBool(&so.broadcastEnabled, v) +} + +// GetPassCred gets value for SO_PASSCRED option. +func (so *SocketOptions) GetPassCred() bool { + return atomic.LoadUint32(&so.passCredEnabled) != 0 +} + +// SetPassCred sets value for SO_PASSCRED option. +func (so *SocketOptions) SetPassCred(v bool) { + storeAtomicBool(&so.passCredEnabled, v) +} diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 7a501acdc..cb7dec1ea 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -74,8 +74,30 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { } func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + netHdr := pkt.NetworkHeader().View() + _, dst := f.proto.ParseAddresses(netHdr) + + addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), CanBePrimaryEndpoint) + if addressEndpoint != nil { + addressEndpoint.DecRef() + // Dispatch the packet to the transport protocol. + f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]), pkt) + return + } + + r, err := f.proto.stack.FindRoute(0, "", dst, fwdTestNetNumber, false /* multicastLoop */) + if err != nil { + return + } + defer r.Release() + + vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + pkt = NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: vv.ToView().ToVectorisedView(), + }) + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + _ = r.WriteHeaderIncludedPacket(pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -106,8 +128,13 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf panic("not implemented") } -func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { + // The network header should not already be populated. + if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok { + return tcpip.ErrMalformedHeader + } + + return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt) } func (f *fwdTestNetworkEndpoint) Close() { @@ -117,6 +144,8 @@ func (f *fwdTestNetworkEndpoint) Close() { // fwdTestNetworkProtocol is a network-layer protocol that implements Address // resolution. type fwdTestNetworkProtocol struct { + stack *Stack + addrCache *linkAddrCache neigh *neighborCache addrResolveDelay time.Duration @@ -304,20 +333,6 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer return n, nil } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - p := fwdTestPacketInfo{ - Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}), - } - - select { - case e.C <- p: - default: - } - - return nil -} - // Wait implements stack.LinkEndpoint.Wait. func (*fwdTestLinkEndpoint) Wait() {} @@ -334,7 +349,10 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ - NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }}, + NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol { + proto.stack = s + return proto + }}, UseNeighborCache: useNeighborCache, }) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 60c81a3aa..3e6ceff28 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -232,7 +232,8 @@ func (n *NIC) setPromiscuousMode(enable bool) { n.mu.Unlock() } -func (n *NIC) isPromiscuousMode() bool { +// Promiscuous implements NetworkInterface. +func (n *NIC) Promiscuous() bool { n.mu.RLock() rv := n.mu.promiscuous n.mu.RUnlock() @@ -320,16 +321,21 @@ func (n *NIC) setSpoofing(enable bool) { // primaryAddress returns an address that can be used to communicate with // remoteAddr. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint { - n.mu.RLock() - spoofing := n.mu.spoofing - n.mu.RUnlock() - ep, ok := n.networkEndpoints[protocol] if !ok { return nil } - return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing) + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return nil + } + + n.mu.RLock() + spoofing := n.mu.spoofing + n.mu.RUnlock() + + return addressableEndpoint.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing) } type getAddressBehaviour int @@ -388,11 +394,17 @@ func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, addre // getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean // is passed to indicate whether or not we should generate temporary endpoints. func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { - if ep, ok := n.networkEndpoints[protocol]; ok { - return ep.AcquireAssignedAddress(address, createTemp, peb) + ep, ok := n.networkEndpoints[protocol] + if !ok { + return nil } - return nil + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return nil + } + + return addressableEndpoint.AcquireAssignedAddress(address, createTemp, peb) } // addAddress adds a new address to n, so that it starts accepting packets @@ -403,7 +415,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo return tcpip.ErrUnknownProtocol } - addressEndpoint, err := ep.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return tcpip.ErrNotSupported + } + + addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) if err == nil { // We have no need for the address endpoint. addressEndpoint.DecRef() @@ -416,7 +433,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for p, ep := range n.networkEndpoints { - for _, a := range ep.PermanentAddresses() { + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + continue + } + + for _, a := range addressableEndpoint.PermanentAddresses() { addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } } @@ -427,7 +449,12 @@ func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for p, ep := range n.networkEndpoints { - for _, a := range ep.PrimaryAddresses() { + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + continue + } + + for _, a := range addressableEndpoint.PrimaryAddresses() { addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } } @@ -445,13 +472,23 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit return tcpip.AddressWithPrefix{} } - return ep.MainAddress() + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return tcpip.AddressWithPrefix{} + } + + return addressableEndpoint.MainAddress() } // removeAddress removes an address from n. func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error { for _, ep := range n.networkEndpoints { - if err := ep.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + continue + } + + if err := addressableEndpoint.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { continue } else { return err @@ -564,13 +601,6 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { return false } -func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { - r := makeRoute(protocol, dst, src, n, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) - defer r.Release() - r.PopulatePacketInfo(pkt) - n.getNetworkEndpoint(protocol).HandlePacket(pkt) -} - // DeliverNetworkPacket finds the appropriate network protocol endpoint and // hands the packet over for further processing. This function is called when // the NIC receives a packet from the link endpoint. @@ -592,7 +622,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n.stats.Rx.Packets.Increment() n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) - netProto, ok := n.stack.networkProtocols[protocol] + networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { n.mu.RUnlock() n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -617,11 +647,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp ep.HandlePacket(n.id, local, protocol, p) } - if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { - n.stack.stats.IP.PacketsReceived.Increment() - } - // Parse headers. + netProto := n.stack.NetworkProtocolInstance(protocol) transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) if !ok { // The packet is too small to contain a network header. @@ -636,9 +663,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) - if n.stack.handleLocal && !n.IsLoopback() { + src, _ := netProto.ParseAddresses(pkt.NetworkHeader().View()) if r := n.getAddress(protocol, src); r != nil { r.DecRef() @@ -651,78 +677,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - // Loopback traffic skips the prerouting chain. - if !n.IsLoopback() { - // iptables filtering. - ipt := n.stack.IPTables() - address := n.primaryAddress(protocol) - if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok { - // iptables is telling us to drop the packet. - n.stack.stats.IP.IPTablesPreroutingDropped.Increment() - return - } - } - - if addressEndpoint := n.getAddress(protocol, dst); addressEndpoint != nil { - n.handlePacket(protocol, dst, src, remote, addressEndpoint, pkt) - return - } - - // This NIC doesn't care about the packet. Find a NIC that cares about the - // packet and forward it to the NIC. - // - // TODO: Should we be forwarding the packet even if promiscuous? - if n.stack.Forwarding(protocol) { - r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) - if err != nil { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - return - } - - // Found a NIC. - n := r.localAddressNIC - if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { - if n.isValidForOutgoing(addressEndpoint) { - pkt.NICID = n.ID() - r.RemoteAddress = src - pkt.NetworkPacketInfo = r.networkPacketInfo() - n.getNetworkEndpoint(protocol).HandlePacket(pkt) - addressEndpoint.DecRef() - r.Release() - return - } - - addressEndpoint.DecRef() - } - - // n doesn't have a destination endpoint. - // Send the packet out of n. - // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease - // the TTL field for ipv4/ipv6. - - // pkt may have set its header and may not have enough headroom for - // link-layer header for the other link to prepend. Here we create a new - // packet to forward. - fwdPkt := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()), - // We need to do a deep copy of the IP packet because WritePacket (and - // friends) take ownership of the packet buffer, but we do not own it. - Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(), - }) - - // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - } - - r.Release() - return - } - - // If a packet socket handled the packet, don't treat it as invalid. - if len(packetEPs) == 0 { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - } + networkEndpoint.HandlePacket(pkt) } // DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 00e9a82ae..43ca03ada 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -65,10 +65,6 @@ const ( // NetworkPacketInfo holds information about a network layer packet. type NetworkPacketInfo struct { - // RemoteAddressBroadcast is true if the packet's remote address is a - // broadcast address. - RemoteAddressBroadcast bool - // LocalAddressBroadcast is true if the packet's local address is a broadcast // address. LocalAddressBroadcast bool @@ -266,10 +262,10 @@ const ( // NetOptions is an interface that allows us to pass network protocol specific // options through the Stack layer code. type NetOptions interface { - // AllocationSize returns the amount of memory that must be allocated to + // SizeWithPadding returns the amount of memory that must be allocated to // hold the options given that the value must be rounded up to the next // multiple of 4 bytes. - AllocationSize() int + SizeWithPadding() int } // NetworkHeaderParams are the header parameters given as input by the @@ -518,6 +514,9 @@ type NetworkInterface interface { // Enabled returns true if the interface is enabled. Enabled() bool + // Promiscuous returns true if the interface is in promiscuous mode. + Promiscuous() bool + // WritePacketToRemote writes the packet to the given remote link address. WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error } @@ -525,8 +524,6 @@ type NetworkInterface interface { // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { - AddressableEndpoint - // Enable enables the endpoint. // // Must only be called when the stack is in a state that allows the endpoint @@ -742,10 +739,6 @@ type LinkEndpoint interface { // endpoint. Capabilities() LinkEndpointCapabilities - // WriteRawPacket writes a packet directly to the link. The packet - // should already have an ethernet header. It takes ownership of vv. - WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error - // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. // diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 15ff437c7..53cb6694f 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -170,28 +170,6 @@ func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) } -// PopulatePacketInfo populates a packet buffer's packet information fields. -// -// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by -// the network layer. -func (r *Route) PopulatePacketInfo(pkt *PacketBuffer) { - if r.local() { - pkt.RXTransportChecksumValidated = true - } - pkt.NetworkPacketInfo = r.networkPacketInfo() -} - -// networkPacketInfo returns the network packet information of the route. -// -// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by -// the network layer. -func (r *Route) networkPacketInfo() NetworkPacketInfo { - return NetworkPacketInfo{ - RemoteAddressBroadcast: r.IsOutboundBroadcast(), - LocalAddressBroadcast: r.isInboundBroadcast(), - } -} - // NICID returns the id of the NIC from which this route originates. func (r *Route) NICID() tcpip.NICID { return r.outgoingNIC.ID() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 0fe157128..f4504e633 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -82,6 +82,7 @@ type TCPRACKState struct { FACK seqnum.Value RTT time.Duration Reord bool + DSACKSeen bool } // TCPEndpointID is the unique 4 tuple that identifies a given endpoint. @@ -1080,7 +1081,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { flags := NICStateFlags{ Up: true, // Netstack interfaces are always up. Running: nic.Enabled(), - Promiscuous: nic.isPromiscuousMode(), + Promiscuous: nic.Promiscuous(), Loopback: nic.IsLoopback(), } nics[id] = NICInfo{ @@ -1809,49 +1810,20 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip nic.unregisterPacketEndpoint(netProto, ep) } -// WritePacket writes data directly to the specified NIC. It adds an ethernet -// header based on the arguments. -func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { +// WritePacketToRemote writes a payload on the specified NIC using the provided +// network protocol and remote link address. +func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { s.mu.Lock() nic, ok := s.nics[nicID] s.mu.Unlock() if !ok { return tcpip.ErrUnknownDevice } - - // Add our own fake ethernet header. - ethFields := header.EthernetFields{ - SrcAddr: nic.LinkEndpoint.LinkAddress(), - DstAddr: dst, - Type: netProto, - } - fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) - fakeHeader.Encode(ðFields) - vv := buffer.View(fakeHeader).ToVectorisedView() - vv.Append(payload) - - if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil { - return err - } - - return nil -} - -// WriteRawPacket writes data directly to the specified NIC without adding any -// headers. -func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error { - s.mu.Lock() - nic, ok := s.nics[nicID] - s.mu.Unlock() - if !ok { - return tcpip.ErrUnknownDevice - } - - if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil { - return err - } - - return nil + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(nic.MaxHeaderLength()), + Data: payload, + }) + return nic.WritePacketToRemote(remote, nil, netProto, pkt) } // NetworkProtocolInstance returns the protocol instance in the stack for the diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index dedfdd435..0d94af139 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -112,7 +112,15 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 { func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. netHdr := pkt.NetworkHeader().View() - f.proto.packetCount[int(netHdr[dstAddrOffset])%len(f.proto.packetCount)]++ + + dst := tcpip.Address(netHdr[dstAddrOffset:][:1]) + addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint == nil { + return + } + addressEndpoint.DecRef() + + f.proto.packetCount[int(dst[0])%len(f.proto.packetCount)]++ // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { @@ -159,9 +167,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params hdr[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { - pkt := pkt.Clone() - r.PopulatePacketInfo(pkt) - f.HandlePacket(pkt) + f.HandlePacket(pkt.Clone()) } if r.Loop&stack.PacketOut == 0 { return nil @@ -2214,88 +2220,6 @@ func TestNICStats(t *testing.T) { } } -func TestNICForwarding(t *testing.T) { - const nicID1 = 1 - const nicID2 = 2 - const dstAddr = tcpip.Address("\x03") - - tests := []struct { - name string - headerLen uint16 - }{ - { - name: "Zero header length", - }, - { - name: "Non-zero header length", - headerLen: 16, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - s.SetForwarding(fakeNetNumber, true) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err) - } - - ep2 := channelLinkWithHeaderLength{ - Endpoint: channel.New(10, defaultMTU, ""), - headerLength: test.headerLen, - } - if err := s.CreateNIC(nicID2, &ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } - if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err) - } - - // Route all packets to dstAddr to NIC 2. - { - subnet, err := tcpip.NewSubnet(dstAddr, "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}}) - } - - // Send a packet to dstAddr. - buf := buffer.NewView(30) - buf[dstAddrOffset] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - pkt, ok := ep2.Read() - if !ok { - t.Fatal("packet not forwarded") - } - - // Test that the link's MaxHeaderLength is honoured. - if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want { - t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want) - } - - // Test that forwarding increments Tx stats correctly. - if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) - } - - if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) - } - }) - } -} - // TestNICContextPreservation tests that you can read out via stack.NICInfo the // Context data you pass via NICContext.Context in stack.CreateNICWithOptions. func TestNICContextPreservation(t *testing.T) { @@ -4228,3 +4152,63 @@ func TestFindRouteWithForwarding(t *testing.T) { }) } } + +func TestWritePacketToRemote(t *testing.T) { + const nicID = 1 + const MTU = 1280 + e := channel.New(1, MTU, linkAddr1) + s := stack.New(stack.Options{}) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("CreateNIC(%d) = %s", nicID, err) + } + tests := []struct { + name string + protocol tcpip.NetworkProtocolNumber + payload []byte + }{ + { + name: "SuccessIPv4", + protocol: header.IPv4ProtocolNumber, + payload: []byte{1, 2, 3, 4}, + }, + { + name: "SuccessIPv6", + protocol: header.IPv6ProtocolNumber, + payload: []byte{5, 6, 7, 8}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if err := s.WritePacketToRemote(nicID, linkAddr2, test.protocol, buffer.View(test.payload).ToVectorisedView()); err != nil { + t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s", err) + } + + pkt, ok := e.Read() + if got, want := ok, true; got != want { + t.Fatalf("e.Read() = %t, want %t", got, want) + } + if got, want := pkt.Proto, test.protocol; got != want { + t.Fatalf("pkt.Proto = %d, want %d", got, want) + } + if got, want := pkt.Route.RemoteLinkAddress, linkAddr2; got != want { + t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", got, want) + } + if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" { + t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff) + } + }) + } + + t.Run("InvalidNICID", func(t *testing.T) { + if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want { + t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want) + } + pkt, ok := e.Read() + if got, want := ok, false; got != want { + t.Fatalf("e.Read() = %t, %v; want %t", got, pkt, want) + } + }) +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index c457b67a2..5b9043d85 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" @@ -47,6 +46,9 @@ type fakeTransportEndpoint struct { // acceptQueue is non-nil iff bound. acceptQueue []fakeTransportEndpoint + + // ops is used to set and get socket options. + ops tcpip.SocketOptions } func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo { @@ -59,6 +61,9 @@ func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats { func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} +func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { + return &f.ops +} func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } @@ -184,9 +189,9 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai if len(f.acceptQueue) == 0 { return nil, nil, nil } - a := f.acceptQueue[0] + a := &f.acceptQueue[0] f.acceptQueue = f.acceptQueue[1:] - return &a, nil, nil + return a, nil, nil } func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { @@ -553,87 +558,3 @@ func TestTransportOptions(t *testing.T) { t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true") } } - -func TestTransportForwarding(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, - }) - s.SetForwarding(fakeNetNumber, true) - - // TODO(b/123449044): Change this to a channel NIC. - ep1 := loopback.New() - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatalf("CreateNIC #1 failed: %v", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress #1 failed: %v", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatalf("CreateNIC #2 failed: %v", err) - } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress #2 failed: %v", err) - } - - // Route all packets to address 3 to NIC 2 and all packets to address - // 1 to NIC 1. - { - subnet0, err := tcpip.NewSubnet("\x03", "\xff") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet0, Gateway: "\x00", NIC: 2}, - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - }) - } - - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Addr: "\x01", NIC: 1}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Send a packet to address 1 from address 3. - req := buffer.NewView(30) - req[0] = 1 - req[1] = 3 - req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: req.ToVectorisedView(), - })) - - aep, _, err := ep.Accept(nil) - if err != nil || aep == nil { - t.Fatalf("Accept failed: %v, %v", aep, err) - } - - resp := buffer.NewView(30) - if _, _, err := aep.Write(tcpip.SlicePayload(resp), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - p, ok := ep2.Read() - if !ok { - t.Fatal("Response packet not forwarded") - } - - nh := stack.PayloadSince(p.Pkt.NetworkHeader()) - if dst := nh[0]; dst != 3 { - t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) - } - if src := nh[1]; src != 1 { - t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) - } -} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 3ab2b7654..09361360f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -634,6 +634,10 @@ type Endpoint interface { // LastError clears and returns the last error reported by the endpoint. LastError() *Error + + // SocketOptions returns the structure which contains all the socket + // level options. + SocketOptions() *SocketOptions } // LinkPacketInfo holds Link layer information for a received packet. @@ -694,15 +698,10 @@ type WriteOptions struct { type SockOptBool int const ( - // BroadcastOption is used by SetSockOptBool/GetSockOptBool to specify - // whether datagram sockets are allowed to send packets to a broadcast - // address. - BroadcastOption SockOptBool = iota - // CorkOption is used by SetSockOptBool/GetSockOptBool to specify if // data should be held until segments are full by the TCP transport // protocol. - CorkOption + CorkOption SockOptBool = iota // DelayOption is used by SetSockOptBool/GetSockOptBool to specify if // data should be sent out immediately by the transport protocol. For @@ -722,12 +721,6 @@ const ( // whether UDP checksum is disabled for this socket. NoChecksumOption - // PasscredOption is used by SetSockOptBool/GetSockOptBool to specify - // whether SCM_CREDENTIALS socket control messages are enabled. - // - // Only supported on Unix sockets. - PasscredOption - // QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool. QuickAckOption @@ -1464,9 +1457,13 @@ type ICMPStats struct { // IPStats collects IP-specific stats (both v4 and v6). type IPStats struct { // PacketsReceived is the total number of IP packets received from the - // link layer in nic.DeliverNetworkPacket. + // link layer. PacketsReceived *StatCounter + // DisabledPacketsReceived is the total number of IP packets received from the + // link layer when the IP layer is disabled. + DisabledPacketsReceived *StatCounter + // InvalidDestinationAddressesReceived is the total number of IP packets // received with an unknown or invalid destination address. InvalidDestinationAddressesReceived *StatCounter diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index bf7594268..39343b966 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -229,19 +229,6 @@ func TestForwarding(t *testing.T) { t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) } - if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) - } - if err := routerStack.AddAddress(routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress, err) - } - if err := routerStack.AddAddress(routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress, err) - } - if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) - } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) } diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index fe7c1bb3d..bf8a1241f 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -140,13 +140,6 @@ func TestPing(t *testing.T) { t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) } - if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) - } - if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) - } - if err := host1Stack.AddProtocolAddress(host1NICID, ipv4Addr1); err != nil { t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv4Addr1, err) } diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 1eecd7957..9d30329f5 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -514,9 +514,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) { t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err) } - if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { - t.Fatalf("eps[%d].SetSockOptBool(tcpip.BroadcastOption, true): %s", len(eps), err) - } + ep.SocketOptions().SetBroadcast(true) bindAddr := tcpip.FullAddress{Port: localPort} if bindWildcard { diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 763cd8f84..fe6514bcd 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -79,6 +79,9 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner + + // ops is used to get socket level options. + ops tcpip.SocketOptions } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { @@ -853,3 +856,8 @@ func (*endpoint) Wait() {} func (*endpoint) LastError() *tcpip.Error { return nil } + +// SocketOptions implements tcpip.Endpoint.SocketOptions. +func (e *endpoint) SocketOptions() *tcpip.SocketOptions { + return &e.ops +} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 31831a6d8..3bff3755a 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -89,6 +89,9 @@ type endpoint struct { // lastErrorMu protects lastError. lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` + + // ops is used to get socket level options. + ops tcpip.SocketOptions } // NewEndpoint returns a new packet endpoint. @@ -549,3 +552,7 @@ func (ep *endpoint) Stats() tcpip.EndpointStats { } func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} + +func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { + return &ep.ops +} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 7b6a87ba9..0a1e1fbb3 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -89,6 +89,9 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner + + // ops is used to get socket level options. + ops tcpip.SocketOptions } // NewEndpoint returns a raw endpoint for the given protocols. @@ -753,6 +756,12 @@ func (e *endpoint) Stats() tcpip.EndpointStats { // Wait implements stack.TransportEndpoint.Wait. func (*endpoint) Wait() {} +// LastError implements tcpip.Endpoint.LastError. func (*endpoint) LastError() *tcpip.Error { return nil } + +// SocketOptions implements tcpip.Endpoint.SocketOptions. +func (e *endpoint) SocketOptions() *tcpip.SocketOptions { + return &e.ops +} diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 47982ca41..6e5adc383 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -235,11 +235,15 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i return n, nil } -// createEndpointAndPerformHandshake creates a new endpoint in connected state -// and then performs the TCP 3-way handshake. +// startHandshake creates a new endpoint in connecting state and then sends +// the SYN-ACK for the TCP 3-way handshake. It returns the state of the +// handshake in progress, which includes the new endpoint in the SYN-RCVD +// state. // -// The new endpoint is returned with e.mu held. -func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { +// On success, a handshake h is returned with h.ep.mu held. +// +// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. +func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) @@ -257,10 +261,8 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // listenEP is nil when listenContext is used by tcp.Forwarder. deferAccept := time.Duration(0) if l.listenEP != nil { - l.listenEP.mu.Lock() if l.listenEP.EndpointState() != StateListen { - l.listenEP.mu.Unlock() // Ensure we release any registrations done by the newly // created endpoint. ep.mu.Unlock() @@ -278,16 +280,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head ep.mu.Unlock() ep.Close() - if l.listenEP != nil { - l.removePendingEndpoint(ep) - l.listenEP.mu.Unlock() - } + l.removePendingEndpoint(ep) return nil, tcpip.ErrConnectionAborted } deferAccept = l.listenEP.deferAccept - l.listenEP.mu.Unlock() } // Register new endpoint so that packets are routed to it. @@ -306,28 +304,33 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head ep.isRegistered = true - // Perform the 3-way handshake. - h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) - if err := h.execute(); err != nil { - ep.mu.Unlock() - ep.Close() - ep.notifyAborted() - - if l.listenEP != nil { - l.removePendingEndpoint(ep) - } - - ep.drainClosingSegmentQueue() - + // Initialize and start the handshake. + h := ep.newPassiveHandshake(isn, irs, opts, deferAccept) + if err := h.start(); err != nil { + l.cleanupFailedHandshake(h) return nil, err } - ep.isConnectNotified = true + return h, nil +} - // Update the receive window scaling. We can't do it before the - // handshake because it's possible that the peer doesn't support window - // scaling. - ep.rcv.rcvWndScale = h.effectiveRcvWndScale() +// performHandshake performs a TCP 3-way handshake. On success, the new +// established endpoint is returned with e.mu held. +// +// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. +func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { + h, err := l.startHandshake(s, opts, queue, owner) + if err != nil { + return nil, err + } + ep := h.ep + if err := h.complete(); err != nil { + ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() + ep.stats.FailedConnectionAttempts.Increment() + l.cleanupFailedHandshake(h) + return nil, err + } + l.cleanupCompletedHandshake(h) return ep, nil } @@ -354,6 +357,39 @@ func (l *listenContext) closeAllPendingEndpoints() { l.pending.Wait() } +// Precondition: h.ep.mu must be held. +func (l *listenContext) cleanupFailedHandshake(h *handshake) { + e := h.ep + e.mu.Unlock() + e.Close() + e.notifyAborted() + if l.listenEP != nil { + l.removePendingEndpoint(e) + } + e.drainClosingSegmentQueue() + e.h = nil +} + +// cleanupCompletedHandshake transfers any state from the completed handshake to +// the new endpoint. +// +// Precondition: h.ep.mu must be held. +func (l *listenContext) cleanupCompletedHandshake(h *handshake) { + e := h.ep + if l.listenEP != nil { + l.removePendingEndpoint(e) + } + e.isConnectNotified = true + + // Update the receive window scaling. We can't do it before the + // handshake because it's possible that the peer doesn't support window + // scaling. + e.rcv.rcvWndScale = e.h.effectiveRcvWndScale() + + // Clean up handshake state stored in the endpoint so that it can be GCed. + e.h = nil +} + // deliverAccepted delivers the newly-accepted endpoint to the listener. If the // endpoint has transitioned out of the listen state (acceptedChan is nil), // the new endpoint is closed instead. @@ -433,23 +469,40 @@ func (e *endpoint) notifyAborted() { // // A limited number of these goroutines are allowed before TCP starts using SYN // cookies to accept connections. -func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { - defer ctx.synRcvdCount.dec() +// +// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. +func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) *tcpip.Error { defer s.decRef() - n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner) + h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - e.decSynRcvdCount() - return + e.synRcvdCount-- + return err } - ctx.removePendingEndpoint(n) - e.decSynRcvdCount() - n.startAcceptedLoop() - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - e.deliverAccepted(n) + go func() { + defer ctx.synRcvdCount.dec() + if err := h.complete(); err != nil { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + ctx.cleanupFailedHandshake(h) + e.mu.Lock() + e.synRcvdCount-- + e.mu.Unlock() + return + } + ctx.cleanupCompletedHandshake(h) + e.mu.Lock() + e.synRcvdCount-- + e.mu.Unlock() + h.ep.startAcceptedLoop() + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + e.deliverAccepted(h.ep) + }() // S/R-SAFE: synRcvdCount is the barrier. + + return nil } func (e *endpoint) incSynRcvdCount() bool { @@ -462,12 +515,6 @@ func (e *endpoint) incSynRcvdCount() bool { return canInc } -func (e *endpoint) decSynRcvdCount() { - e.mu.Lock() - e.synRcvdCount-- - e.mu.Unlock() -} - func (e *endpoint) acceptQueueIsFull() bool { e.acceptMu.Lock() full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan) @@ -477,6 +524,8 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. +// +// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error { e.rcvListMu.Lock() rcvClosed := e.rcvClosed @@ -500,7 +549,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er // backlog. if !e.acceptQueueIsFull() && e.incSynRcvdCount() { s.incRef() - go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier. + _ = e.handleSynSegment(ctx, s, &opts) return nil } ctx.synRcvdCount.dec() @@ -712,7 +761,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { // to the endpoint. e.setEndpointState(StateClose) - // close any endpoints in SYN-RCVD state. + // Close any endpoints in SYN-RCVD state. ctx.closeAllPendingEndpoints() // Do cleanup if needed. diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 6e9015be1..6661e8915 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -102,21 +102,26 @@ type handshake struct { // been received. This is required to stop retransmitting the // original SYN-ACK when deferAccept is enabled. acked bool + + // sendSYNOpts is the cached values for the SYN options to be sent. + sendSYNOpts header.TCPSynOptions } -func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { - h := handshake{ - ep: ep, +func (e *endpoint) newHandshake() *handshake { + h := &handshake{ + ep: e, active: true, - rcvWnd: rcvWnd, - rcvWndScale: ep.rcvWndScaleForHandshake(), + rcvWnd: seqnum.Size(e.initialReceiveWindow()), + rcvWndScale: e.rcvWndScaleForHandshake(), } h.resetState() + // Store reference to handshake state in endpoint. + e.h = h return h } -func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake { - h := newHandshake(ep, rcvWnd) +func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) *handshake { + h := e.newHandshake() h.resetToSynRcvd(isn, irs, opts, deferAccept) return h } @@ -491,7 +496,7 @@ func (h *handshake) resolveRoute() *tcpip.Error { h.ep.mu.Lock() } if n¬ifyError != 0 { - return h.ep.LastError() + return h.ep.lastErrorLocked() } } @@ -502,8 +507,9 @@ func (h *handshake) resolveRoute() *tcpip.Error { } } -// execute executes the TCP 3-way handshake. -func (h *handshake) execute() *tcpip.Error { +// start resolves the route if necessary and sends the first +// SYN/SYN-ACK. +func (h *handshake) start() *tcpip.Error { if h.ep.route.IsResolutionRequired() { if err := h.resolveRoute(); err != nil { return err @@ -511,19 +517,7 @@ func (h *handshake) execute() *tcpip.Error { } h.startTime = time.Now() - // Initialize the resend timer. - resendWaker := sleep.Waker{} - timeOut := time.Duration(time.Second) - rt := time.AfterFunc(timeOut, resendWaker.Assert) - defer rt.Stop() - - // Set up the wakers. - s := sleep.Sleeper{} - s.AddWaker(&resendWaker, wakerForResend) - s.AddWaker(&h.ep.notificationWaker, wakerForNotification) - s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) - defer s.Done() - + h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) var sackEnabled tcpip.TCPSACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { // If stack returned an error when checking for SACKEnabled @@ -531,10 +525,6 @@ func (h *handshake) execute() *tcpip.Error { sackEnabled = false } - // Send the initial SYN segment and loop until the handshake is - // completed. - h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) - synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, TS: true, @@ -544,9 +534,8 @@ func (h *handshake) execute() *tcpip.Error { MSS: h.ep.amss, } - // Execute is also called in a listen context so we want to make sure we - // only send the TS/SACK option when we received the TS/SACK in the - // initial SYN. + // start() is also called in a listen context so we want to make sure we only + // send the TS/SACK option when we received the TS/SACK in the initial SYN. if h.state == handshakeSynRcvd { synOpts.TS = h.ep.sendTSOk synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled) @@ -557,6 +546,7 @@ func (h *handshake) execute() *tcpip.Error { } } + h.sendSYNOpts = synOpts h.ep.sendSynTCP(&h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, @@ -566,7 +556,25 @@ func (h *handshake) execute() *tcpip.Error { ack: h.ackNum, rcvWnd: h.rcvWnd, }, synOpts) + return nil +} + +// complete completes the TCP 3-way handshake initiated by h.start(). +func (h *handshake) complete() *tcpip.Error { + // Set up the wakers. + s := sleep.Sleeper{} + resendWaker := sleep.Waker{} + s.AddWaker(&resendWaker, wakerForResend) + s.AddWaker(&h.ep.notificationWaker, wakerForNotification) + s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) + defer s.Done() + // Initialize the resend timer. + timer, err := newBackoffTimer(time.Second, MaxRTO, resendWaker.Assert) + if err != nil { + return err + } + defer timer.stop() for h.state != handshakeCompleted { // Unlock before blocking, and reacquire again afterwards (h.ep.mu is held // throughout handshake processing). @@ -576,11 +584,9 @@ func (h *handshake) execute() *tcpip.Error { switch index { case wakerForResend: - timeOut *= 2 - if timeOut > MaxRTO { - return tcpip.ErrTimeout + if err := timer.reset(); err != nil { + return err } - rt.Reset(timeOut) // Resend the SYN/SYN-ACK only if the following conditions hold. // - It's an active handshake (deferAccept does not apply) // - It's a passive handshake and we have not yet got the final-ACK. @@ -598,7 +604,7 @@ func (h *handshake) execute() *tcpip.Error { seq: h.iss, ack: h.ackNum, rcvWnd: h.rcvWnd, - }, synOpts) + }, h.sendSYNOpts) } case wakerForNotification: @@ -624,9 +630,8 @@ func (h *handshake) execute() *tcpip.Error { h.ep.mu.Lock() } if n¬ifyError != 0 { - return h.ep.LastError() + return h.ep.lastErrorLocked() } - case wakerForNewSegment: if err := h.processSegments(); err != nil { return err @@ -637,6 +642,34 @@ func (h *handshake) execute() *tcpip.Error { return nil } +type backoffTimer struct { + timeout time.Duration + maxTimeout time.Duration + t *time.Timer +} + +func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, *tcpip.Error) { + if timeout > maxTimeout { + return nil, tcpip.ErrTimeout + } + bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout} + bt.t = time.AfterFunc(timeout, f) + return bt, nil +} + +func (bt *backoffTimer) reset() *tcpip.Error { + bt.timeout *= 2 + if bt.timeout > MaxRTO { + return tcpip.ErrTimeout + } + bt.t.Reset(bt.timeout) + return nil +} + +func (bt *backoffTimer) stop() { + bt.t.Stop() +} + func parseSynSegmentOptions(s *segment) header.TCPSynOptions { synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck)) if synOpts.TS { @@ -967,7 +1000,7 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. e.setEndpointState(StateError) - e.HardError = err + e.hardError = err if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { // The exact sequence number to be used for the RST is the same as the // one used by Linux. We need to handle the case of window being shrunk @@ -1106,7 +1139,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // delete the TCB, and return. case StateCloseWait: e.transitionToStateCloseLocked() - e.HardError = tcpip.ErrAborted + e.hardError = tcpip.ErrAborted e.notifyProtocolGoroutine(notifyTickleWorker) return false, nil default: @@ -1318,7 +1351,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ epilogue := func() { // e.mu is expected to be hold upon entering this section. - if e.snd != nil { e.snd.resendTimer.cleanup() } @@ -1342,20 +1374,13 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } if handshake { - // This is an active connection, so we must initiate the 3-way - // handshake, and then inform potential waiters about its - // completion. - initialRcvWnd := e.initialReceiveWindow() - h := newHandshake(e, seqnum.Size(initialRcvWnd)) - h.ep.setEndpointState(StateSynSent) - - if err := h.execute(); err != nil { + if err := e.h.complete(); err != nil { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() e.setEndpointState(StateError) - e.HardError = err + e.hardError = err e.workerCleanup = true // Lock released below. @@ -1364,9 +1389,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } } - e.keepalive.timer.init(&e.keepalive.waker) - defer e.keepalive.timer.cleanup() - drained := e.drainDone != nil if drained { close(e.drainDone) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 23b9de8c5..36b915510 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -315,11 +315,6 @@ func (*Stats) IsEndpointStats() {} // +stateify savable type EndpointInfo struct { stack.TransportEndpointInfo - - // HardError is meaningful only when state is stateError. It stores the - // error to be returned when read/write syscalls are called and the - // endpoint is in this state. HardError is protected by endpoint mu. - HardError *tcpip.Error `state:".(string)"` } // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo @@ -386,6 +381,11 @@ type endpoint struct { waiterQueue *waiter.Queue `state:"wait"` uniqueID uint64 + // hardError is meaningful only when state is stateError. It stores the + // error to be returned when read/write syscalls are called and the + // endpoint is in this state. hardError is protected by endpoint mu. + hardError *tcpip.Error `state:".(string)"` + // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. lastErrorMu sync.Mutex `state:"nosave"` @@ -440,9 +440,11 @@ type endpoint struct { ttl uint8 v6only bool isConnectNotified bool - // TCP should never broadcast but Linux nevertheless supports enabling/ - // disabling SO_BROADCAST, albeit as a NOOP. - broadcast bool + + // h stores a reference to the current handshake state if the endpoint is in + // the SYN-SENT or SYN-RECV states, in which case endpoint == endpoint.h.ep. + // nil otherwise. + h *handshake `state:"nosave"` // portFlags stores the current values of port related flags. portFlags ports.Flags @@ -685,6 +687,9 @@ type endpoint struct { // linger is used for SO_LINGER socket option. linger tcpip.LingerOption + + // ops is used to get socket level options. + ops tcpip.SocketOptions } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -922,6 +927,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.segmentQueue.ep = e e.tsOffset = timeStampOffset() e.acceptCond = sync.NewCond(&e.acceptMu) + e.keepalive.timer.init(&e.keepalive.waker) return e } @@ -1146,6 +1152,7 @@ func (e *endpoint) cleanupLocked() { // Close all endpoints that might have been accepted by TCP but not by // the client. e.closePendingAcceptableConnectionsLocked() + e.keepalive.timer.cleanup() e.workerCleanup = false @@ -1272,11 +1279,20 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvListMu.Unlock() } +// SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -func (e *endpoint) LastError() *tcpip.Error { +// Preconditions: e.mu must be held to call this function. +func (e *endpoint) hardErrorLocked() *tcpip.Error { + err := e.hardError + e.hardError = nil + return err +} + +// Preconditions: e.mu must be held to call this function. +func (e *endpoint) lastErrorLocked() *tcpip.Error { e.lastErrorMu.Lock() defer e.lastErrorMu.Unlock() err := e.lastError @@ -1284,6 +1300,16 @@ func (e *endpoint) LastError() *tcpip.Error { return err } +// LastError implements tcpip.Endpoint.LastError. +func (e *endpoint) LastError() *tcpip.Error { + e.LockUser() + defer e.UnlockUser() + if err := e.hardErrorLocked(); err != nil { + return err + } + return e.lastErrorLocked() +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() @@ -1305,9 +1331,11 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, bufUsed := e.rcvBufUsed if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() - he := e.HardError if s == StateError { - return buffer.View{}, tcpip.ControlMessages{}, he + if err := e.hardErrorLocked(); err != nil { + return buffer.View{}, tcpip.ControlMessages{}, err + } + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } e.stats.ReadErrors.NotConnected.Increment() return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected @@ -1363,9 +1391,13 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // indicating the reason why it's not writable. // Caller must hold e.mu and e.sndBufMu func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { + // The endpoint cannot be written to if it's not connected. switch s := e.EndpointState(); { case s == StateError: - return 0, e.HardError + if err := e.hardErrorLocked(); err != nil { + return 0, err + } + return 0, tcpip.ErrClosedForSend case !s.connecting() && !s.connected(): return 0, tcpip.ErrClosedForSend case s.connecting(): @@ -1479,7 +1511,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // but has some pending unread data. if s := e.EndpointState(); !s.connected() && s != StateClose { if s == StateError { - return 0, tcpip.ControlMessages{}, e.HardError + return 0, tcpip.ControlMessages{}, e.hardErrorLocked() } e.stats.ReadErrors.InvalidEndpointState.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState @@ -1599,11 +1631,6 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { switch opt { - case tcpip.BroadcastOption: - e.LockUser() - e.broadcast = v - e.UnlockUser() - case tcpip.CorkOption: e.LockUser() if !v { @@ -1950,11 +1977,6 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { - case tcpip.BroadcastOption: - e.LockUser() - v := e.broadcast - e.UnlockUser() - return v, nil case tcpip.CorkOption: return atomic.LoadUint32(&e.cork) != 0, nil @@ -2185,6 +2207,8 @@ func (*endpoint) Disconnect() *tcpip.Error { func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { err := e.connect(addr, true, true) if err != nil && !err.IgnoreStats() { + // Connect failed. Let's wake up any waiters. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() } @@ -2244,7 +2268,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return tcpip.ErrAlreadyConnecting case StateError: - return e.HardError + if err := e.hardErrorLocked(); err != nil { + return err + } + return tcpip.ErrConnectionAborted default: return tcpip.ErrInvalidEndpointState @@ -2397,14 +2424,70 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } if run { - e.workerRunning = true - e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() - go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. + if err := e.startMainLoop(handshake); err != nil { + return err + } } return tcpip.ErrConnectStarted } +// startMainLoop sends the initial SYN and starts the main loop for the +// endpoint. +func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error { + preloop := func() *tcpip.Error { + if handshake { + h := e.newHandshake() + e.setEndpointState(StateSynSent) + if err := h.start(); err != nil { + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + e.setEndpointState(StateError) + e.hardError = err + + // Call cleanupLocked to free up any reservations. + e.cleanupLocked() + return err + } + } + e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() + return nil + } + + if e.route.IsResolutionRequired() { + // If the endpoint is closed between releasing e.mu and the goroutine below + // acquiring it, make sure that cleanup is deferred to the new goroutine. + e.workerRunning = true + + // Sending the initial SYN may block due to route resolution; do it in a + // separate goroutine to avoid blocking the syscall goroutine. + go func() { // S/R-SAFE: will be drained before save. + e.mu.Lock() + if err := preloop(); err != nil { + e.workerRunning = false + e.mu.Unlock() + return + } + e.mu.Unlock() + _ = e.protocolMainLoop(handshake, nil) + }() + return nil + } + + // No route resolution is required, so we can send the initial SYN here without + // blocking. This will hopefully reduce overall latency by overlapping time + // spent waiting for a SYN-ACK and time spent spinning up a new goroutine + // for the main loop. + if err := preloop(); err != nil { + return err + } + e.workerRunning = true + go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. + return nil +} + // ConnectEndpoint is not supported. func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { return tcpip.ErrInvalidEndpointState @@ -3061,6 +3144,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { FACK: rc.fack, RTT: rc.rtt, Reord: rc.reorderSeen, + DSACKSeen: rc.dsackSeen, } return s } @@ -3130,3 +3214,8 @@ func (e *endpoint) Wait() { <-notifyCh } } + +// SocketOptions implements tcpip.Endpoint.SocketOptions. +func (e *endpoint) SocketOptions() *tcpip.SocketOptions { + return &e.ops +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 2bcc5e1c2..ba67176b5 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -172,6 +172,7 @@ func (e *endpoint) afterLoad() { // Condition variables and mutexs are not S/R'ed so reinitialize // acceptCond with e.acceptMu. e.acceptCond = sync.NewCond(&e.acceptMu) + e.keepalive.timer.init(&e.keepalive.waker) stack.StackFromEnv.RegisterRestoredEndpoint(e) } @@ -320,21 +321,21 @@ func (e *endpoint) loadRecentTSTime(unix unixTime) { } // saveHardError is invoked by stateify. -func (e *EndpointInfo) saveHardError() string { - if e.HardError == nil { +func (e *endpoint) saveHardError() string { + if e.hardError == nil { return "" } - return e.HardError.String() + return e.hardError.String() } // loadHardError is invoked by stateify. -func (e *EndpointInfo) loadHardError(s string) { +func (e *endpoint) loadHardError(s string) { if s == "" { return } - e.HardError = tcpip.StringToError(s) + e.hardError = tcpip.StringToError(s) } // saveMeasureTime is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 0664789da..596178625 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -152,7 +152,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, } f := r.forwarder - ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{ + ep, err := f.listen.performHandshake(r.segment, &header.TCPSynOptions{ MSS: r.synOptions.MSS, WS: r.synOptions.WS, TS: r.synOptions.TS, diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index d312b1b8b..e0a50a919 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -29,12 +29,12 @@ import ( // // +stateify savable type rackControl struct { + // dsackSeen indicates if the connection has seen a DSACK. + dsackSeen bool + // endSequence is the ending TCP sequence number of rackControl.seg. endSequence seqnum.Value - // dsack indicates if the connection has seen a DSACK. - dsack bool - // fack is the highest selectively or cumulatively acknowledged // sequence. fack seqnum.Value @@ -122,3 +122,8 @@ func (rc *rackControl) detectReorder(seg *segment) { rc.reorderSeen = true } } + +// setDSACKSeen updates rack control if duplicate SACK is seen by the connection. +func (rc *rackControl) setDSACKSeen() { + rc.dsackSeen = true +} diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index ab5fa4fb7..0e0fdf14c 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -1285,25 +1285,29 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // steps 2 and 3. func (s *sender) walkSACK(rcvdSeg *segment) { - if len(rcvdSeg.parsedOptions.SACKBlocks) == 0 { + // Look for DSACK block. + idx := 0 + n := len(rcvdSeg.parsedOptions.SACKBlocks) + if s.checkDSACK(rcvdSeg) { + s.rc.setDSACKSeen() + idx = 1 + n-- + } + + if n == 0 { return } // Sort the SACK blocks. The first block is the most recent unacked // block. The following blocks can be in arbitrary order. - sackBlocks := make([]header.SACKBlock, len(rcvdSeg.parsedOptions.SACKBlocks)) - copy(sackBlocks, rcvdSeg.parsedOptions.SACKBlocks) + sackBlocks := make([]header.SACKBlock, n) + copy(sackBlocks, rcvdSeg.parsedOptions.SACKBlocks[idx:]) sort.Slice(sackBlocks, func(i, j int) bool { return sackBlocks[j].Start.LessThan(sackBlocks[i].Start) }) seg := s.writeList.Front() for _, sb := range sackBlocks { - // This check excludes DSACK blocks. - if sb.Start.LessThanEq(rcvdSeg.ackNumber) || sb.Start.LessThanEq(s.sndUna) || s.sndNxt.LessThan(sb.End) { - continue - } - for seg != nil && seg.sequenceNumber.LessThan(sb.End) && seg.xmitCount != 0 { if sb.Start.LessThanEq(seg.sequenceNumber) && !seg.acked { s.rc.update(seg, rcvdSeg, s.ep.tsOffset) @@ -1315,6 +1319,50 @@ func (s *sender) walkSACK(rcvdSeg *segment) { } } +// checkDSACK checks if a DSACK is reported and updates it in RACK. +func (s *sender) checkDSACK(rcvdSeg *segment) bool { + n := len(rcvdSeg.parsedOptions.SACKBlocks) + if n == 0 { + return false + } + + sb := rcvdSeg.parsedOptions.SACKBlocks[0] + // Check if SACK block is invalid. + if sb.End.LessThan(sb.Start) { + return false + } + + // See: https://tools.ietf.org/html/rfc2883#section-5 DSACK is sent in + // at most one SACK block. DSACK is detected in the below two cases: + // * If the SACK sequence space is less than this cumulative ACK, it is + // an indication that the segment identified by the SACK block has + // been received more than once by the receiver. + // * If the sequence space in the first SACK block is greater than the + // cumulative ACK, then the sender next compares the sequence space + // in the first SACK block with the sequence space in the second SACK + // block, if there is one. This comparison can determine if the first + // SACK block is reporting duplicate data that lies above the + // cumulative ACK. + if sb.Start.LessThan(rcvdSeg.ackNumber) { + return true + } + + if n > 1 { + sb1 := rcvdSeg.parsedOptions.SACKBlocks[1] + if sb1.End.LessThan(sb1.Start) { + return false + } + + // If the first SACK block is fully covered by second SACK + // block, then the first block is a DSACK block. + if sb.End.LessThanEq(sb1.End) && sb1.Start.LessThanEq(sb.Start) { + return true + } + } + + return false +} + // handleRcvdSegment is called when a segment is received; it is responsible for // updating the send-related state. func (s *sender) handleRcvdSegment(rcvdSeg *segment) { diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index d3f92b48c..9818ffa0f 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -30,15 +30,17 @@ const ( maxPayload = 10 tsOptionSize = 12 maxTCPOptionSize = 40 + mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload ) // TestRACKUpdate tests the RACK related fields are updated when an ACK is // received on a SACK enabled connection. func TestRACKUpdate(t *testing.T) { - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) + c := context.New(t, uint32(mtu)) defer c.Cleanup() var xmitTime time.Time + probeDone := make(chan struct{}) c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { // Validate that the endpoint Sender.RACKState is what we expect. if state.Sender.RACKState.XmitTime.Before(xmitTime) { @@ -54,6 +56,7 @@ func TestRACKUpdate(t *testing.T) { if state.Sender.RACKState.RTT == 0 { t.Fatalf("RACK RTT failed to update when an ACK is received, got RACKState.RTT == 0 want != 0") } + close(probeDone) }) setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) @@ -73,18 +76,20 @@ func TestRACKUpdate(t *testing.T) { c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) bytesRead += maxPayload c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead) - time.Sleep(200 * time.Millisecond) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + <-probeDone } // TestRACKDetectReorder tests that RACK detects packet reordering. func TestRACKDetectReorder(t *testing.T) { - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) + c := context.New(t, uint32(mtu)) defer c.Cleanup() - const ackNum = 2 - var n int - ch := make(chan struct{}) + const ackNumToVerify = 2 + probeDone := make(chan struct{}) c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { gotSeq := state.Sender.RACKState.FACK wantSeq := state.Sender.SndNxt @@ -95,7 +100,7 @@ func TestRACKDetectReorder(t *testing.T) { } n++ - if n < ackNum { + if n < ackNumToVerify { if state.Sender.RACKState.Reord { t.Fatalf("RACK reorder detected when there is no reordering") } @@ -105,11 +110,11 @@ func TestRACKDetectReorder(t *testing.T) { if state.Sender.RACKState.Reord == false { t.Fatalf("RACK reorder detection failed") } - close(ch) + close(probeDone) }) setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - data := buffer.NewView(ackNum * maxPayload) + data := buffer.NewView(ackNumToVerify * maxPayload) for i := range data { data[i] = byte(i) } @@ -120,7 +125,7 @@ func TestRACKDetectReorder(t *testing.T) { } bytesRead := 0 - for i := 0; i < ackNum; i++ { + for i := 0; i < ackNumToVerify; i++ { c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) bytesRead += maxPayload } @@ -133,5 +138,393 @@ func TestRACKDetectReorder(t *testing.T) { // Wait for the probe function to finish processing the ACK before the // test completes. - <-ch + <-probeDone +} + +func sendAndReceive(t *testing.T, c *context.Context, numPackets int) buffer.View { + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + + data := buffer.NewView(numPackets * maxPayload) + for i := range data { + data[i] = byte(i) + } + + // Write the data. + if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + bytesRead := 0 + for i := 0; i < numPackets; i++ { + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + bytesRead += maxPayload + } + + return data +} + +const ( + validDSACKDetected = 1 + failedToDetectDSACK = 2 + invalidDSACKDetected = 3 +) + +func addDSACKSeenCheckerProbe(t *testing.T, c *context.Context, numACK int, probeDone chan int) { + var n int + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that RACK detects DSACK. + n++ + if n < numACK { + if state.Sender.RACKState.DSACKSeen { + probeDone <- invalidDSACKDetected + } + return + } + + if !state.Sender.RACKState.DSACKSeen { + probeDone <- failedToDetectDSACK + return + } + probeDone <- validDSACKDetected + }) +} + +// TestRACKDetectDSACK tests that RACK detects DSACK with duplicate segments. +// See: https://tools.ietf.org/html/rfc2883#section-4.1.1. +func TestRACKDetectDSACK(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan int) + const ackNumToVerify = 2 + addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) + + numPackets := 8 + data := sendAndReceive(t, c, numPackets) + + // Cumulative ACK for [1-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + c.SendAck(seq, bytesRead) + + // Expect retransmission of #6 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + + // Send DSACK block for #6 packet indicating both + // initial and retransmitted packet are received and + // packets [1-7] are received. + start := c.IRS.Add(seqnum.Size(bytesRead)) + end := start.Add(maxPayload) + bytesRead += 2 * maxPayload + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Wait for the probe function to finish processing the + // ACK before the test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } +} + +// TestRACKDetectDSACKWithOutOfOrder tests that RACK detects DSACK with out of +// order segments. +// See: https://tools.ietf.org/html/rfc2883#section-4.1.2. +func TestRACKDetectDSACKWithOutOfOrder(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan int) + const ackNumToVerify = 2 + addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) + + numPackets := 10 + data := sendAndReceive(t, c, numPackets) + + // Cumulative ACK for [1-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + c.SendAck(seq, bytesRead) + + // Expect retransmission of #6 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + + // Send DSACK block for #6 packet indicating both + // initial and retransmitted packet are received and + // packets [1-7] are received. + start := c.IRS.Add(seqnum.Size(bytesRead)) + end := start.Add(maxPayload) + bytesRead += 2 * maxPayload + // Send DSACK block for #6 along with out of + // order #9 packet is received. + start1 := c.IRS.Add(seqnum.Size(bytesRead) + maxPayload) + end1 := start1.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}, {start1, end1}}) + + // Wait for the probe function to finish processing the + // ACK before the test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } +} + +// TestRACKDetectDSACKWithOutOfOrderDup tests that DSACK is detected on a +// duplicate of out of order packet. +// See: https://tools.ietf.org/html/rfc2883#section-4.1.3 +func TestRACKDetectDSACKWithOutOfOrderDup(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan int) + const ackNumToVerify = 4 + addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) + + numPackets := 10 + sendAndReceive(t, c, numPackets) + + // ACK [1-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + c.SendAck(seq, bytesRead) + + // Send SACK indicating #6 packet is missing and received #7 packet. + offset := seqnum.Size(bytesRead + maxPayload) + start := c.IRS.Add(1 + offset) + end := start.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Send SACK with #6 packet is missing and received [7-8] packets. + end = start.Add(2 * maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Consider #8 packet is duplicated on the network and send DSACK. + dsackStart := c.IRS.Add(1 + offset + maxPayload) + dsackEnd := dsackStart.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } +} + +// TestRACKDetectDSACKSingleDup tests DSACK for a single duplicate subsegment. +// See: https://tools.ietf.org/html/rfc2883#section-4.2.1. +func TestRACKDetectDSACKSingleDup(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan int) + const ackNumToVerify = 4 + addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) + + numPackets := 4 + data := sendAndReceive(t, c, numPackets) + + // Send ACK for #1 packet. + bytesRead := maxPayload + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, bytesRead) + + // Missing [2-3] packets and received #4 packet. + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Expect retransmission of #2 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + + // ACK for retransmitted #2 packet. + bytesRead += maxPayload + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Simulate receving delayed subsegment of #2 packet and delayed #3 packet by + // sending DSACK block for the subsegment. + dsackStart := c.IRS.Add(1 + seqnum.Size(bytesRead)) + dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) + c.SendAckWithSACK(seq, numPackets*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}}) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } +} + +// TestRACKDetectDSACKDupWithCumulativeACK tests DSACK for two non-contiguous +// duplicate subsegments covered by the cumulative acknowledgement. +// See: https://tools.ietf.org/html/rfc2883#section-4.2.2. +func TestRACKDetectDSACKDupWithCumulativeACK(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan int) + const ackNumToVerify = 5 + addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) + + numPackets := 6 + data := sendAndReceive(t, c, numPackets) + + // Send ACK for #1 packet. + bytesRead := maxPayload + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, bytesRead) + + // Missing [2-5] packets and received #6 packet. + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(1 + seqnum.Size(5*maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Expect retransmission of #2 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + + // Received delayed #2 packet. + bytesRead += maxPayload + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Received delayed #4 packet. + start1 := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) + end1 := start1.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) + + // Simulate receiving retransmitted subsegment for #2 packet and delayed #3 + // packet by sending DSACK block for #2 packet. + dsackStart := c.IRS.Add(1 + seqnum.Size(maxPayload)) + dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) + c.SendAckWithSACK(seq, 4*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } +} + +// TestRACKDetectDSACKDup tests two non-contiguous duplicate subsegments not +// covered by the cumulative acknowledgement. +// See: https://tools.ietf.org/html/rfc2883#section-4.2.3. +func TestRACKDetectDSACKDup(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan int) + const ackNumToVerify = 5 + addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) + + numPackets := 7 + data := sendAndReceive(t, c, numPackets) + + // Send ACK for #1 packet. + bytesRead := maxPayload + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, bytesRead) + + // Missing [2-6] packets and SACK #7 packet. + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Received delayed #3 packet. + start1 := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) + end1 := start1.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) + + // Expect retransmission of #2 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + + // Consider #2 packet has been dropped and SACK #4 packet. + start2 := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) + end2 := start2.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start2, end2}, {start1, end1}, {start, end}}) + + // Simulate receiving retransmitted subsegment for #3 packet and delayed #5 + // packet by sending DSACK block for the subsegment. + dsackStart := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) + dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) + end1 = end1.Add(seqnum.Size(2 * maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start1, end1}}) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } +} + +// TestRACKWithInvalidDSACKBlock tests that DSACK is not detected when DSACK +// is not the first SACK block. +func TestRACKWithInvalidDSACKBlock(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan struct{}) + const ackNumToVerify = 2 + var n int + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that RACK does not detect DSACK when DSACK block is + // not the first SACK block. + n++ + t.Helper() + if state.Sender.RACKState.DSACKSeen { + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } + + if n == ackNumToVerify { + close(probeDone) + } + }) + + numPackets := 10 + data := sendAndReceive(t, c, numPackets) + + // Cumulative ACK for [1-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + c.SendAck(seq, bytesRead) + + // Expect retransmission of #6 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + + // Send DSACK block for #6 packet indicating both + // initial and retransmitted packet are received and + // packets [1-7] are received. + start := c.IRS.Add(seqnum.Size(bytesRead)) + end := start.Add(maxPayload) + bytesRead += 2 * maxPayload + + // Send DSACK block as second block. + start1 := c.IRS.Add(seqnum.Size(bytesRead) + maxPayload) + end1 := start1.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) + + // Wait for the probe function to finish processing the + // ACK before the test completes. + <-probeDone } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index fcc3c5000..c366a4cbc 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -75,9 +75,6 @@ func TestGiveUpConnect(t *testing.T) { // Wait for ep to become writable. <-notifyCh - if err := ep.LastError(); err != tcpip.ErrAborted { - t.Fatalf("got ep.LastError() = %s, want = %s", err, tcpip.ErrAborted) - } // Call Connect again to retreive the handshake failure status // and stats updates. @@ -3198,6 +3195,11 @@ loop: case tcpip.ErrWouldBlock: select { case <-ch: + // Expect the state to be StateError and subsequent Reads to fail with HardError. + if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) + } + break loop case <-time.After(1 * time.Second): t.Fatalf("Timed out waiting for reset to arrive") } @@ -3207,14 +3209,10 @@ loop: t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) } } - // Expect the state to be StateError and subsequent Reads to fail with HardError. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) - } + if tcp.EndpointState(c.EP.State()) != tcp.StateError { t.Fatalf("got EP state is not StateError") } - if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got) } @@ -5717,6 +5715,50 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { } } +func TestSYNRetransmit(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + // Start listening. + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send the same SYN packet multiple times. We should still get a valid SYN-ACK + // reply. + irs := seqnum.Value(789) + for i := 0; i < 5; i++ { + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + } + + // Receive the SYN-ACK reply. + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.TCPAckNum(uint32(irs) + 1), + } + checker.IPv4(t, c.GetPacket(), checker.TCP(tcpCheckers...)) +} + func TestSynRcvdBadSeqNumber(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go index 7981d469b..38a335840 100644 --- a/pkg/tcpip/transport/tcp/timer.go +++ b/pkg/tcpip/transport/tcp/timer.go @@ -84,6 +84,10 @@ func (t *timer) init(w *sleep.Waker) { // cleanup frees all resources associated with the timer. func (t *timer) cleanup() { + if t.timer == nil { + // No cleanup needed. + return + } t.timer.Stop() *t = timer{} } diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index c78549424..7ebae63d8 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -56,6 +56,7 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 9bcb918bb..81601f559 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -108,7 +108,6 @@ type endpoint struct { multicastLoop bool portFlags ports.Flags bindToDevice tcpip.NICID - broadcast bool noChecksum bool lastErrorMu sync.Mutex `state:"nosave"` @@ -157,6 +156,9 @@ type endpoint struct { // linger is used for SO_LINGER socket option. linger tcpip.LingerOption + + // ops is used to get socket level options. + ops tcpip.SocketOptions } // +stateify savable @@ -427,7 +429,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c to := opts.To e.mu.RLock() - defer e.mu.RUnlock() + lockReleased := false + defer func() { + if lockReleased { + return + } + e.mu.RUnlock() + }() // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { @@ -473,7 +481,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c if e.state != StateConnected { err = tcpip.ErrInvalidEndpointState } - return + return ch, err } } else { // Reject destination address if it goes through a different @@ -508,7 +516,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c resolve = route.Resolve } - if !e.broadcast && route.IsOutboundBroadcast() { + if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { return 0, nil, tcpip.ErrBroadcastDisabled } @@ -539,7 +547,24 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil { + localPort := e.ID.LocalPort + sendTOS := e.sendTOS + owner := e.owner + noChecksum := e.noChecksum + lockReleased = true + e.mu.RUnlock() + + // Do not hold lock when sending as loopback is synchronous and if the UDP + // datagram ends up generating an ICMP response then it can result in a + // deadlock where the ICMP response handling ends up acquiring this endpoint's + // mutex using e.mu.RLock() in endpoint.HandleControlPacket which can cause a + // deadlock if another caller is trying to acquire e.mu in exclusive mode w/ + // e.mu.Lock(). Since e.mu.Lock() prevents any new read locks to ensure the + // lock can be eventually acquired. + // + // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read + // locking is prohibited. + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil { return 0, nil, err } return int64(len(v)), nil, nil @@ -553,11 +578,6 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { switch opt { - case tcpip.BroadcastOption: - e.mu.Lock() - e.broadcast = v - e.mu.Unlock() - case tcpip.MulticastLoopOption: e.mu.Lock() e.multicastLoop = v @@ -614,7 +634,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.v6only = v } - return nil } @@ -830,12 +849,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { - case tcpip.BroadcastOption: - e.mu.RLock() - v := e.broadcast - e.mu.RUnlock() - return v, nil - case tcpip.KeepaliveEnabledOption: return false, nil @@ -1522,6 +1535,12 @@ func isBroadcastOrMulticast(a tcpip.Address) bool { return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) } +// SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } + +// SocketOptions implements tcpip.Endpoint.SocketOptions. +func (e *endpoint) SocketOptions() *tcpip.SocketOptions { + return &e.ops +} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index c09c7aa86..492e277a8 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -32,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -54,6 +55,7 @@ const ( stackPort = 1234 testAddr = "\x0a\x00\x00\x02" testPort = 4096 + invalidPort = 8192 multicastAddr = "\xe8\x2b\xd3\xea" multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" broadcastAddr = header.IPv4Broadcast @@ -295,7 +297,8 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() return newDualTestContextWithOptions(t, mtu, stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, + HandleLocal: true, }) } @@ -364,9 +367,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { c.t.Fatalf("SetSockOptBool failed: %s", err) } } else if flow.isBroadcast() { - if err := c.ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { - c.t.Fatalf("SetSockOptBool failed: %s", err) - } + c.ep.SocketOptions().SetBroadcast(true) } } @@ -974,7 +975,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { // provided. func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { c.t.Helper() - return testWriteInternal(c, flow, true, checkers...) + return testWriteAndVerifyInternal(c, flow, true, checkers...) } // testWriteWithoutDestination sends a packet of the given test flow from the @@ -983,10 +984,10 @@ func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker // checker functions provided. func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { c.t.Helper() - return testWriteInternal(c, flow, false, checkers...) + return testWriteAndVerifyInternal(c, flow, false, checkers...) } -func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { +func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View { c.t.Helper() // Take a snapshot of the stats to validate them at the end of the test. epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() @@ -1008,6 +1009,12 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ... c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) } c.checkEndpointWriteStats(1, epstats, err) + return payload +} + +func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + payload := testWriteNoVerify(c, flow, setDest) // Received the packet and check the payload. b := c.getPacketAndVerify(flow, checkers...) var udp header.UDP @@ -1152,6 +1159,39 @@ func TestV4WriteOnConnected(t *testing.T) { testWriteWithoutDestination(c, unicastV4) } +func TestWriteOnConnectedInvalidPort(t *testing.T) { + protocols := map[string]tcpip.NetworkProtocolNumber{ + "ipv4": ipv4.ProtocolNumber, + "ipv6": ipv6.ProtocolNumber, + } + for name, pn := range protocols { + t.Run(name, func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(pn) + if err := c.ep.Connect(tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}); err != nil { + c.t.Fatalf("Connect failed: %s", err) + } + writeOpts := tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}, + } + payload := buffer.View(newPayload()) + n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) + if err != nil { + c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err) + } + if got, want := n, int64(len(payload)); got != want { + c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want) + } + + if err := c.ep.LastError(); err != tcpip.ErrConnectionRefused { + c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err) + } + }) + } +} + // TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket // that is bound to a V4 multicast address. func TestWriteOnBoundToV4Multicast(t *testing.T) { @@ -1451,6 +1491,10 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) Promiscuous() bool { + return false +} + func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -2393,17 +2437,13 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) } - if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { - t.Fatalf("got SetSockOptBool(BroadcastOption, true): %s", err) - } + ep.SocketOptions().SetBroadcast(true) if n, _, err := ep.Write(data, opts); err != nil { t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err) } - if err := ep.SetSockOptBool(tcpip.BroadcastOption, false); err != nil { - t.Fatalf("got SetSockOptBool(BroadcastOption, false): %s", err) - } + ep.SocketOptions().SetBroadcast(false) if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go index 49ab87c58..976331230 100644 --- a/pkg/test/testutil/testutil.go +++ b/pkg/test/testutil/testutil.go @@ -36,7 +36,6 @@ import ( "path/filepath" "strconv" "strings" - "sync/atomic" "syscall" "testing" "time" @@ -417,33 +416,35 @@ func StartReaper() func() { // WaitUntilRead reads from the given reader until the wanted string is found // or until timeout. -func WaitUntilRead(r io.Reader, want string, split bufio.SplitFunc, timeout time.Duration) error { +func WaitUntilRead(r io.Reader, want string, timeout time.Duration) error { sc := bufio.NewScanner(r) - if split != nil { - sc.Split(split) - } // done must be accessed atomically. A value greater than 0 indicates // that the read loop can exit. - var done uint32 - doneCh := make(chan struct{}) + doneCh := make(chan bool) + defer close(doneCh) go func() { for sc.Scan() { t := sc.Text() if strings.Contains(t, want) { - atomic.StoreUint32(&done, 1) - close(doneCh) - break + doneCh <- true + return } - if atomic.LoadUint32(&done) > 0 { - break + select { + case <-doneCh: + return + default: } } + doneCh <- false }() + select { case <-time.After(timeout): - atomic.StoreUint32(&done, 1) return fmt.Errorf("timeout waiting to read %q", want) - case <-doneCh: + case res := <-doneCh: + if !res { + return fmt.Errorf("reader closed while waiting to read %q", want) + } return nil } } |