diff options
Diffstat (limited to 'pkg/sentry/fs/proc')
-rw-r--r-- | pkg/sentry/fs/proc/sys_net.go | 42 | ||||
-rw-r--r-- | pkg/sentry/fs/proc/sys_net_state.go | 19 | ||||
-rw-r--r-- | pkg/sentry/fs/proc/sys_net_test.go | 13 | ||||
-rw-r--r-- | pkg/sentry/fs/proc/task.go | 2 |
4 files changed, 39 insertions, 37 deletions
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index f2f49a7f6..e555672ad 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -55,7 +55,7 @@ type tcpMemInode struct { // size stores the tcp buffer size during save, and sets the buffer // size in netstack in restore. We must save/restore this here, since - // netstack itself is stateless. + // a netstack instance is created on restore. size inet.TCPBufferSize // mu protects against concurrent reads/writes to files based on this @@ -259,6 +259,9 @@ func (f *tcpSackFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSeque if src.NumBytes() == 0 { return 0, nil } + + // Only consider size of one memory page for input for performance reasons. + // We are only reading if it's zero or not anyway. src = src.TakeFirst(usermem.PageSize - 1) var v int32 @@ -390,21 +393,14 @@ func (p *proc) newSysNetCore(ctx context.Context, msrc *fs.MountSource, s inet.S // // +stateify savable type ipForwarding struct { - stack inet.Stack `state:".(ipForwardingState)"` fsutil.SimpleFileInode -} -// ipForwardingState is used to stores a state of netstack -// for packet forwarding because netstack itself is stateless. -// -// +stateify savable -type ipForwardingState struct { stack inet.Stack `state:"wait"` - // enabled stores packet forwarding settings during save, and sets it back - // in netstack in restore. We must save/restore this here, since - // netstack itself is stateless. - enabled bool + // enabled stores the IPv4 forwarding state on save. + // We must save/restore this here, since a netstack instance + // is created on restore. + enabled *bool } func newIPForwardingInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { @@ -441,6 +437,8 @@ type ipForwardingFile struct { fsutil.FileUseInodeUnstableAttr `state:"nosave"` waiter.AlwaysReady `state:"nosave"` + ipf *ipForwarding + stack inet.Stack `state:"wait"` } @@ -450,6 +448,7 @@ func (ipf *ipForwarding) GetFile(ctx context.Context, dirent *fs.Dirent, flags f flags.Pwrite = true return fs.NewFile(ctx, dirent, flags, &ipForwardingFile{ stack: ipf.stack, + ipf: ipf, }), nil } @@ -459,14 +458,18 @@ func (f *ipForwardingFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOS return 0, io.EOF } + if f.ipf.enabled == nil { + enabled := f.stack.Forwarding(ipv4.ProtocolNumber) + f.ipf.enabled = &enabled + } + val := "0\n" - if f.stack.Forwarding(ipv4.ProtocolNumber) { + if *f.ipf.enabled { // Technically, this is not quite compatible with Linux. Linux // stores these as an integer, so if you write "2" into // ip_forward, you should get 2 back. val = "1\n" } - n, err := dst.CopyOut(ctx, []byte(val)) return int64(n), err } @@ -479,7 +482,8 @@ func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IO return 0, nil } - // Only consider size of one memory page for input. + // Only consider size of one memory page for input for performance reasons. + // We are only reading if it's zero or not anyway. src = src.TakeFirst(usermem.PageSize - 1) var v int32 @@ -487,9 +491,11 @@ func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IO if err != nil { return n, err } - - enabled := v != 0 - return n, f.stack.SetForwarding(ipv4.ProtocolNumber, enabled) + if f.ipf.enabled == nil { + f.ipf.enabled = new(bool) + } + *f.ipf.enabled = v != 0 + return n, f.stack.SetForwarding(ipv4.ProtocolNumber, *f.ipf.enabled) } func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { diff --git a/pkg/sentry/fs/proc/sys_net_state.go b/pkg/sentry/fs/proc/sys_net_state.go index 3fadb870e..4cb4741af 100644 --- a/pkg/sentry/fs/proc/sys_net_state.go +++ b/pkg/sentry/fs/proc/sys_net_state.go @@ -45,18 +45,11 @@ func (s *tcpSack) afterLoad() { } } -// saveStack is invoked by stateify. -func (ipf *ipForwarding) saveStack() ipForwardingState { - return ipForwardingState{ - ipf.stack, - ipf.stack.Forwarding(ipv4.ProtocolNumber), - } -} - -// loadStack is invoked by stateify. -func (ipf *ipForwarding) loadStack(s ipForwardingState) { - ipf.stack = s.stack - if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, s.enabled); err != nil { - panic(fmt.Sprintf("failed to set previous IPv4 forwarding configuration [%v]: %v", s.enabled, err)) +// afterLoad is invoked by stateify. +func (ipf *ipForwarding) afterLoad() { + if ipf.enabled != nil { + if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, *ipf.enabled); err != nil { + panic(fmt.Sprintf("failed to set IPv4 forwarding [%v]: %v", *ipf.enabled, err)) + } } } diff --git a/pkg/sentry/fs/proc/sys_net_test.go b/pkg/sentry/fs/proc/sys_net_test.go index 72c9857d0..6ef5738e7 100644 --- a/pkg/sentry/fs/proc/sys_net_test.go +++ b/pkg/sentry/fs/proc/sys_net_test.go @@ -176,18 +176,21 @@ func TestIPForwarding(t *testing.T) { for _, c := range cases { t.Run(c.comment, func(t *testing.T) { s.IPForwarding = c.initial - - file := &ipForwardingFile{stack: s} + ipf := &ipForwarding{stack: s} + file := &ipForwardingFile{ + stack: s, + ipf: ipf, + } // Write the values. src := usermem.BytesIOSequence([]byte(c.str)) if n, err := file.Write(ctx, nil, src, 0); n != int64(len(c.str)) || err != nil { - t.Errorf("file.Write(ctx, nil, %v, 0) = (%d, %v); wanted (%d, nil)", c.str, n, err, len(c.str)) + t.Errorf("file.Write(ctx, nil, %q, 0) = (%d, %v); want (%d, nil)", c.str, n, err, len(c.str)) } // Read the values from the stack and check them. - if s.IPForwarding != c.final { - t.Errorf("s.IPForwarding = %v; wanted %v", s.IPForwarding, c.final) + if got, want := s.IPForwarding, c.final; got != want { + t.Errorf("s.IPForwarding incorrect; got: %v, want: %v", got, want) } }) diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 9cf7f2a62..103bfc600 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -604,7 +604,7 @@ func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ( var vss, rss, data uint64 s.t.WithMuLocked(func(t *kernel.Task) { if fdTable := t.FDTable(); fdTable != nil { - fds = fdTable.Size() + fds = fdTable.CurrentMaxFDs() } if mm := t.MemoryManager(); mm != nil { vss = mm.VirtualMemorySize() |