diff options
author | gVisor bot <gvisor-bot@google.com> | 2020-09-20 18:17:20 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-09-20 18:17:20 -0700 |
commit | ca308747205020c957d7fea3929f6c26004a6dd3 (patch) | |
tree | 1d6b2c0f7306c3a740a2e039e6b448972e64f00e /pkg | |
parent | 916751039cca927a0e64b4e6f776d2d4732cf8d8 (diff) | |
parent | ac324f646ee3cb7955b0b45a7453aeb9671cbdf1 (diff) |
Merge pull request #3651 from ianlewis:ip-forwarding
PiperOrigin-RevId: 332760843
Diffstat (limited to 'pkg')
24 files changed, 476 insertions, 52 deletions
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index 77c2c5c0e..b8b2281a8 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -50,6 +50,7 @@ go_library( "//pkg/sync", "//pkg/syserror", "//pkg/tcpip/header", + "//pkg/tcpip/network/ipv4", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index 8615b60f0..e555672ad 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -54,7 +55,7 @@ type tcpMemInode struct { // size stores the tcp buffer size during save, and sets the buffer // size in netstack in restore. We must save/restore this here, since - // netstack itself is stateless. + // a netstack instance is created on restore. size inet.TCPBufferSize // mu protects against concurrent reads/writes to files based on this @@ -258,6 +259,9 @@ func (f *tcpSackFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSeque if src.NumBytes() == 0 { return 0, nil } + + // Only consider size of one memory page for input for performance reasons. + // We are only reading if it's zero or not anyway. src = src.TakeFirst(usermem.PageSize - 1) var v int32 @@ -383,11 +387,125 @@ func (p *proc) newSysNetCore(ctx context.Context, msrc *fs.MountSource, s inet.S return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil) } +// ipForwarding implements fs.InodeOperations. +// +// ipForwarding is used to enable/disable packet forwarding of netstack. +// +// +stateify savable +type ipForwarding struct { + fsutil.SimpleFileInode + + stack inet.Stack `state:"wait"` + + // enabled stores the IPv4 forwarding state on save. + // We must save/restore this here, since a netstack instance + // is created on restore. + enabled *bool +} + +func newIPForwardingInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { + ipf := &ipForwarding{ + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), + stack: s, + } + sattr := fs.StableAttr{ + DeviceID: device.ProcDevice.DeviceID(), + InodeID: device.ProcDevice.NextIno(), + BlockSize: usermem.PageSize, + Type: fs.SpecialFile, + } + return fs.NewInode(ctx, ipf, msrc, sattr) +} + +// Truncate implements fs.InodeOperations.Truncate. Truncate is called when +// O_TRUNC is specified for any kind of existing Dirent but is not called via +// (f)truncate for proc files. +func (*ipForwarding) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + +// +stateify savable +type ipForwardingFile struct { + fsutil.FileGenericSeek `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + waiter.AlwaysReady `state:"nosave"` + + ipf *ipForwarding + + stack inet.Stack `state:"wait"` +} + +// GetFile implements fs.InodeOperations.GetFile. +func (ipf *ipForwarding) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + flags.Pread = true + flags.Pwrite = true + return fs.NewFile(ctx, dirent, flags, &ipForwardingFile{ + stack: ipf.stack, + ipf: ipf, + }), nil +} + +// Read implements fs.FileOperations.Read. +func (f *ipForwardingFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + return 0, io.EOF + } + + if f.ipf.enabled == nil { + enabled := f.stack.Forwarding(ipv4.ProtocolNumber) + f.ipf.enabled = &enabled + } + + val := "0\n" + if *f.ipf.enabled { + // Technically, this is not quite compatible with Linux. Linux + // stores these as an integer, so if you write "2" into + // ip_forward, you should get 2 back. + val = "1\n" + } + n, err := dst.CopyOut(ctx, []byte(val)) + return int64(n), err +} + +// Write implements fs.FileOperations.Write. +// +// Offset is ignored, multiple writes are not supported. +func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) { + if src.NumBytes() == 0 { + return 0, nil + } + + // Only consider size of one memory page for input for performance reasons. + // We are only reading if it's zero or not anyway. + src = src.TakeFirst(usermem.PageSize - 1) + + var v int32 + n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) + if err != nil { + return n, err + } + if f.ipf.enabled == nil { + f.ipf.enabled = new(bool) + } + *f.ipf.enabled = v != 0 + return n, f.stack.SetForwarding(ipv4.ProtocolNumber, *f.ipf.enabled) +} + func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { contents := map[string]*fs.Inode{ // Add tcp_sack. "tcp_sack": newTCPSackInode(ctx, msrc, s), + // Add ip_forward. + "ip_forward": newIPForwardingInode(ctx, msrc, s), + // The following files are simple stubs until they are // implemented in netstack, most of these files are // configuration related. We use the value closest to the diff --git a/pkg/sentry/fs/proc/sys_net_state.go b/pkg/sentry/fs/proc/sys_net_state.go index 6eba709c6..4cb4741af 100644 --- a/pkg/sentry/fs/proc/sys_net_state.go +++ b/pkg/sentry/fs/proc/sys_net_state.go @@ -14,7 +14,11 @@ package proc -import "fmt" +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" +) // beforeSave is invoked by stateify. func (t *tcpMemInode) beforeSave() { @@ -40,3 +44,12 @@ func (s *tcpSack) afterLoad() { } } } + +// afterLoad is invoked by stateify. +func (ipf *ipForwarding) afterLoad() { + if ipf.enabled != nil { + if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, *ipf.enabled); err != nil { + panic(fmt.Sprintf("failed to set IPv4 forwarding [%v]: %v", *ipf.enabled, err)) + } + } +} diff --git a/pkg/sentry/fs/proc/sys_net_test.go b/pkg/sentry/fs/proc/sys_net_test.go index 355e83d47..6ef5738e7 100644 --- a/pkg/sentry/fs/proc/sys_net_test.go +++ b/pkg/sentry/fs/proc/sys_net_test.go @@ -123,3 +123,76 @@ func TestConfigureRecvBufferSize(t *testing.T) { } } } + +// TestIPForwarding tests the implementation of +// /proc/sys/net/ipv4/ip_forwarding +func TestIPForwarding(t *testing.T) { + ctx := context.Background() + s := inet.NewTestStack() + + var cases = []struct { + comment string + initial bool + str string + final bool + }{ + { + comment: `Forwarding is disabled; write 1 and enable forwarding`, + initial: false, + str: "1", + final: true, + }, + { + comment: `Forwarding is disabled; write 0 and disable forwarding`, + initial: false, + str: "0", + final: false, + }, + { + comment: `Forwarding is enabled; write 1 and enable forwarding`, + initial: true, + str: "1", + final: true, + }, + { + comment: `Forwarding is enabled; write 0 and disable forwarding`, + initial: true, + str: "0", + final: false, + }, + { + comment: `Forwarding is disabled; write 2404 and enable forwarding`, + initial: false, + str: "2404", + final: true, + }, + { + comment: `Forwarding is enabled; write 2404 and enable forwarding`, + initial: true, + str: "2404", + final: true, + }, + } + for _, c := range cases { + t.Run(c.comment, func(t *testing.T) { + s.IPForwarding = c.initial + ipf := &ipForwarding{stack: s} + file := &ipForwardingFile{ + stack: s, + ipf: ipf, + } + + // Write the values. + src := usermem.BytesIOSequence([]byte(c.str)) + if n, err := file.Write(ctx, nil, src, 0); n != int64(len(c.str)) || err != nil { + t.Errorf("file.Write(ctx, nil, %q, 0) = (%d, %v); want (%d, nil)", c.str, n, err, len(c.str)) + } + + // Read the values from the stack and check them. + if got, want := s.IPForwarding, c.final; got != want { + t.Errorf("s.IPForwarding incorrect; got: %v, want: %v", got, want) + } + + }) + } +} diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index a45b44440..2e086e34c 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -100,6 +100,7 @@ go_library( "//pkg/sync", "//pkg/syserror", "//pkg/tcpip/header", + "//pkg/tcpip/network/ipv4", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 9e0966efe..a3ffbb15e 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/usermem" ) @@ -67,6 +68,7 @@ func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) *ke "tcp_rmem": fs.newDentry(root, fs.NextIno(), 0644, &tcpMemData{stack: stack, dir: tcpRMem}), "tcp_sack": fs.newDentry(root, fs.NextIno(), 0644, &tcpSackData{stack: stack}), "tcp_wmem": fs.newDentry(root, fs.NextIno(), 0644, &tcpMemData{stack: stack, dir: tcpWMem}), + "ip_forward": fs.newDentry(root, fs.NextIno(), 0444, &ipForwarding{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, most of these files are configuration related. We use the @@ -354,3 +356,63 @@ func (d *tcpMemData) writeSizeLocked(size inet.TCPBufferSize) error { panic(fmt.Sprintf("unknown tcpMemFile type: %v", d.dir)) } } + +// ipForwarding implements vfs.WritableDynamicBytesSource for +// /proc/sys/net/ipv4/ip_forwarding. +// +// +stateify savable +type ipForwarding struct { + kernfs.DynamicBytesFile + + stack inet.Stack `state:"wait"` + enabled *bool +} + +var _ vfs.WritableDynamicBytesSource = (*ipForwarding)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (ipf *ipForwarding) Generate(ctx context.Context, buf *bytes.Buffer) error { + if ipf.enabled == nil { + enabled := ipf.stack.Forwarding(ipv4.ProtocolNumber) + ipf.enabled = &enabled + } + + val := "0\n" + if *ipf.enabled { + // Technically, this is not quite compatible with Linux. Linux stores these + // as an integer, so if you write "2" into tcp_sack, you should get 2 back. + // Tough luck. + val = "1\n" + } + buf.WriteString(val) + + return nil +} + +// Write implements vfs.WritableDynamicBytesSource.Write. +func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + // No need to handle partial writes thus far. + return 0, syserror.EINVAL + } + if src.NumBytes() == 0 { + return 0, nil + } + + // Limit input size so as not to impact performance if input size is large. + src = src.TakeFirst(usermem.PageSize - 1) + + var v int32 + n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) + if err != nil { + return 0, err + } + if ipf.enabled == nil { + ipf.enabled = new(bool) + } + *ipf.enabled = v != 0 + if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, *ipf.enabled); err != nil { + return 0, err + } + return n, nil +} diff --git a/pkg/sentry/fsimpl/proc/tasks_sys_test.go b/pkg/sentry/fsimpl/proc/tasks_sys_test.go index be54897bb..6cee22823 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys_test.go @@ -20,8 +20,10 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/inet" + "gvisor.dev/gvisor/pkg/usermem" ) func newIPv6TestStack() *inet.TestStack { @@ -76,3 +78,72 @@ func TestIfinet6(t *testing.T) { t.Errorf("Got n.contents() = %v, want = %v", got, want) } } + +// TestIPForwarding tests the implementation of +// /proc/sys/net/ipv4/ip_forwarding +func TestConfigureIPForwarding(t *testing.T) { + ctx := context.Background() + s := inet.NewTestStack() + + var cases = []struct { + comment string + initial bool + str string + final bool + }{ + { + comment: `Forwarding is disabled; write 1 and enable forwarding`, + initial: false, + str: "1", + final: true, + }, + { + comment: `Forwarding is disabled; write 0 and disable forwarding`, + initial: false, + str: "0", + final: false, + }, + { + comment: `Forwarding is enabled; write 1 and enable forwarding`, + initial: true, + str: "1", + final: true, + }, + { + comment: `Forwarding is enabled; write 0 and disable forwarding`, + initial: true, + str: "0", + final: false, + }, + { + comment: `Forwarding is disabled; write 2404 and enable forwarding`, + initial: false, + str: "2404", + final: true, + }, + { + comment: `Forwarding is enabled; write 2404 and enable forwarding`, + initial: true, + str: "2404", + final: true, + }, + } + for _, c := range cases { + t.Run(c.comment, func(t *testing.T) { + s.IPForwarding = c.initial + + file := &ipForwarding{stack: s, enabled: &c.initial} + + // Write the values. + src := usermem.BytesIOSequence([]byte(c.str)) + if n, err := file.Write(ctx, src, 0); n != int64(len(c.str)) || err != nil { + t.Errorf("file.Write(ctx, nil, %q, 0) = (%d, %v); want (%d, nil)", c.str, n, err, len(c.str)) + } + + // Read the values from the stack and check them. + if got, want := s.IPForwarding, c.final; got != want { + t.Errorf("s.IPForwarding incorrect; got: %v, want: %v", got, want) + } + }) + } +} diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index 07bf39fed..5bba9de0b 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -15,6 +15,7 @@ go_library( ], deps = [ "//pkg/context", + "//pkg/tcpip", "//pkg/tcpip/stack", ], ) diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index c0b4831d1..fbe6d6aa6 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -15,7 +15,10 @@ // Package inet defines semantics for IP stacks. package inet -import "gvisor.dev/gvisor/pkg/tcpip/stack" +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) // Stack represents a TCP/IP stack. type Stack interface { @@ -80,6 +83,12 @@ type Stack interface { // RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful // for restoring a stack after a save. RestoreCleanupEndpoints([]stack.TransportEndpoint) + + // Forwarding returns if packet forwarding between NICs is enabled. + Forwarding(protocol tcpip.NetworkProtocolNumber) bool + + // SetForwarding enables or disables packet forwarding between NICs. + SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error } // Interface contains information about a network interface. diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 9771f01fc..1779cc6f3 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -14,7 +14,10 @@ package inet -import "gvisor.dev/gvisor/pkg/tcpip/stack" +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) // TestStack is a dummy implementation of Stack for tests. type TestStack struct { @@ -26,6 +29,7 @@ type TestStack struct { TCPSendBufSize TCPBufferSize TCPSACKFlag bool Recovery TCPLossRecovery + IPForwarding bool } // NewTestStack returns a TestStack with no network interfaces. The value of @@ -128,3 +132,14 @@ func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint { // RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} + +// Forwarding implements inet.Stack.Forwarding. +func (s *TestStack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + return s.IPForwarding +} + +// SetForwarding implements inet.Stack.SetForwarding. +func (s *TestStack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { + s.IPForwarding = enable + return nil +} diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 632e33452..b6ebe29d6 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -39,6 +39,9 @@ go_library( "//pkg/sentry/vfs", "//pkg/syserr", "//pkg/syserror", + "//pkg/tcpip", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/usermem", "//pkg/waiter", diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index fda3dcb35..faa61160e 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -30,6 +30,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -59,6 +62,8 @@ type Stack struct { tcpSACKEnabled bool netDevFile *os.File netSNMPFile *os.File + ipv4Forwarding bool + ipv6Forwarding bool } // NewStack returns an empty Stack containing no configuration. @@ -118,6 +123,13 @@ func (s *Stack) Configure() error { s.netSNMPFile = f } + s.ipv6Forwarding = false + if ipForwarding, err := ioutil.ReadFile("/proc/sys/net/ipv6/conf/all/forwarding"); err == nil { + s.ipv6Forwarding = strings.TrimSpace(string(ipForwarding)) != "0" + } else { + log.Warningf("Failed to read if ipv6 forwarding is enabled, setting to false") + } + return nil } @@ -468,3 +480,21 @@ func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } // RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} + +// Forwarding implements inet.Stack.Forwarding. +func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + switch protocol { + case ipv4.ProtocolNumber: + return s.ipv4Forwarding + case ipv6.ProtocolNumber: + return s.ipv6Forwarding + default: + log.Warningf("Forwarding(%v) failed: unsupported protocol", protocol) + return false + } +} + +// SetForwarding implements inet.Stack.SetForwarding. +func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { + return syserror.EACCES +} diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 36144e1eb..1028d2a6e 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -412,3 +412,24 @@ func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) { s.Stack.RestoreCleanupEndpoints(es) } + +// Forwarding implements inet.Stack.Forwarding. +func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + switch protocol { + case ipv4.ProtocolNumber, ipv6.ProtocolNumber: + return s.Stack.Forwarding(protocol) + default: + panic(fmt.Sprintf("Forwarding(%v) failed: unsupported protocol", protocol)) + } +} + +// SetForwarding implements inet.Stack.SetForwarding. +func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { + switch protocol { + case ipv4.ProtocolNumber, ipv6.ProtocolNumber: + s.Stack.SetForwarding(protocol, enable) + default: + panic(fmt.Sprintf("SetForwarding(%v) failed: unsupported protocol", protocol)) + } + return nil +} diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index 563bc78ea..c326fab54 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -14,6 +14,8 @@ go_library( go_test( name = "buffer_test", size = "small", - srcs = ["view_test.go"], + srcs = [ + "view_test.go", + ], library = ":buffer", ) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 2b83c421e..7430b8fcd 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -477,7 +477,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme stack := r.Stack() // Is the networking stack operating as a router? - if !stack.Forwarding() { + if !stack.Forwarding(ProtocolNumber) { // ... No, silently drop the packet. received.RouterOnlyPacketsDroppedByHost.Increment() return diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 8112ed051..0f50bfb8e 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -728,7 +728,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { }) if isRouter { // Enabling forwarding makes the stack act as a router. - s.SetForwarding(true) + s.SetForwarding(ProtocolNumber, true) } if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 480c495fa..7434df4a1 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -958,7 +958,7 @@ func TestNDPValidation(t *testing.T) { if isRouter { // Enabling forwarding makes the stack act as a router. - s.SetForwarding(true) + s.SetForwarding(ProtocolNumber, true) } stats := s.Stats().ICMP.V6PacketsReceived diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 54759091a..38c5bac71 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -316,7 +316,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC } // Enable forwarding. - s.SetForwarding(true) + s.SetForwarding(proto.Number(), true) // NIC 1 has the link address "a", and added the network address 1. ep1 = &fwdTestLinkEndpoint{ diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index b0873d1af..97ca00d16 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -817,7 +817,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // per-interface basis; it is a stack-wide configuration, so we check // stack's forwarding flag to determine if the NIC is a routing // interface. - if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding { + if !ndp.configs.HandleRAs || ndp.nic.stack.Forwarding(header.IPv6ProtocolNumber) { return } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 67dc5364f..5e43a9b0b 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1120,7 +1120,7 @@ func TestNoRouterDiscovery(t *testing.T) { }, NDPDisp: &ndpDisp, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -1365,7 +1365,7 @@ func TestNoPrefixDiscovery(t *testing.T) { }, NDPDisp: &ndpDisp, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -1723,7 +1723,7 @@ func TestNoAutoGenAddr(t *testing.T) { }, NDPDisp: &ndpDisp, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -4640,7 +4640,7 @@ func TestCleanupNDPState(t *testing.T) { name: "Enable forwarding", cleanupFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(true) + s.SetForwarding(ipv6.ProtocolNumber, true) }, keepAutoGenLinkLocal: true, maxAutoGenAddrEvents: 4, @@ -5286,11 +5286,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { name: "Enable and disable forwarding", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(false) + s.SetForwarding(ipv6.ProtocolNumber, false) }, stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() - s.SetForwarding(true) + s.SetForwarding(ipv6.ProtocolNumber, true) }, }, diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 204bfc433..be274773c 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -337,7 +337,7 @@ func (n *NIC) enable() *tcpip.Error { // does. That is, routers do not learn from RAs (e.g. on-link prefixes // and default routers). Therefore, soliciting RAs from other routers on // a link is unnecessary for routers. - if !n.stack.forwarding { + if !n.stack.Forwarding(header.IPv6ProtocolNumber) { n.mu.ndp.startSolicitingRouters() } @@ -1303,7 +1303,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // packet and forward it to the NIC. // // TODO: Should we be forwarding the packet even if promiscuous? - if n.stack.Forwarding() { + if n.stack.Forwarding(protocol) { r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) if err != nil { n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() @@ -1330,6 +1330,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // n doesn't have a destination endpoint. // Send the packet out of n. // TODO(b/128629022): move this logic to route.WritePacket. + // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. if ch, err := r.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 6a683545d..68cf77de2 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -405,6 +405,13 @@ type Stack struct { networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver + // forwarding contains the whether packet forwarding is enabled or not for + // different network protocols. + forwarding struct { + sync.RWMutex + protocols map[tcpip.NetworkProtocolNumber]bool + } + // rawFactory creates raw endpoints. If nil, raw endpoints are // disabled. It is set during Stack creation and is immutable. rawFactory RawFactory @@ -415,9 +422,8 @@ type Stack struct { linkAddrCache *linkAddrCache - mu sync.RWMutex - nics map[tcpip.NICID]*NIC - forwarding bool + mu sync.RWMutex + nics map[tcpip.NICID]*NIC // cleanupEndpointsMu protects cleanupEndpoints. cleanupEndpointsMu sync.Mutex @@ -749,6 +755,7 @@ func New(opts Options) *Stack { Max: DefaultMaxBufferSize, }, } + s.forwarding.protocols = make(map[tcpip.NetworkProtocolNumber]bool) // Add specified network protocols. for _, netProto := range opts.NetworkProtocols { @@ -866,46 +873,42 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwarding enables or disables the packet forwarding between NICs. -// -// When forwarding becomes enabled, any host-only state on all NICs will be -// cleaned up and if IPv6 is enabled, NDP Router Solicitations will be started. -// When forwarding becomes disabled and if IPv6 is enabled, NDP Router -// Solicitations will be stopped. -func (s *Stack) SetForwarding(enable bool) { - // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. - s.mu.Lock() - defer s.mu.Unlock() +// SetForwarding enables or disables packet forwarding between NICs. +func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) { + s.forwarding.Lock() + defer s.forwarding.Unlock() - // If forwarding status didn't change, do nothing further. - if s.forwarding == enable { + // If this stack does not support the protocol, do nothing. + if _, ok := s.networkProtocols[protocol]; !ok { return } - s.forwarding = enable - - // If this stack does not support IPv6, do nothing further. - if _, ok := s.networkProtocols[header.IPv6ProtocolNumber]; !ok { + // If the forwarding value for this protocol hasn't changed then do + // nothing. + if forwarding := s.forwarding.protocols[protocol]; forwarding == enable { return } - if enable { - for _, nic := range s.nics { - nic.becomeIPv6Router() - } - } else { - for _, nic := range s.nics { - nic.becomeIPv6Host() + s.forwarding.protocols[protocol] = enable + + if protocol == header.IPv6ProtocolNumber { + if enable { + for _, nic := range s.nics { + nic.becomeIPv6Router() + } + } else { + for _, nic := range s.nics { + nic.becomeIPv6Host() + } } } } -// Forwarding returns if the packet forwarding between NICs is enabled. -func (s *Stack) Forwarding() bool { - // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. - s.mu.RLock() - defer s.mu.RUnlock() - return s.forwarding +// Forwarding returns if packet forwarding between NICs is enabled. +func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + s.forwarding.RLock() + defer s.forwarding.RUnlock() + return s.forwarding.protocols[protocol] } // SetRouteTable assigns the route table to be used by this stack. It diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 60b54c244..7669ba672 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -2091,7 +2091,7 @@ func TestNICForwarding(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) - s.SetForwarding(true) + s.SetForwarding(fakeNetNumber, true) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID1, ep1); err != nil { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index ef3457e32..64e44bc99 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -549,7 +549,7 @@ func TestTransportForwarding(t *testing.T) { NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, }) - s.SetForwarding(true) + s.SetForwarding(fakeNetNumber, true) // TODO(b/123449044): Change this to a channel NIC. ep1 := loopback.New() |