From 352ae1022ce19de28fc72e034cc469872ad79d06 Mon Sep 17 00:00:00 2001 From: aleksej Date: Sun, 27 Oct 2019 15:14:35 +0300 Subject: Add /proc/sys/net/ipv4/ip_forward --- pkg/sentry/fs/proc/BUILD | 1 + pkg/sentry/fs/proc/sys_net.go | 105 +++++++++++++++++++++++++++++++++++ pkg/sentry/fs/proc/sys_net_state.go | 21 ++++++- pkg/sentry/fs/proc/sys_net_test.go | 68 +++++++++++++++++++++++ pkg/sentry/inet/BUILD | 5 +- pkg/sentry/inet/inet.go | 8 +++ pkg/sentry/inet/test_stack.go | 16 ++++++ pkg/sentry/socket/hostinet/BUILD | 3 + pkg/sentry/socket/hostinet/stack.go | 30 ++++++++++ pkg/sentry/socket/netstack/stack.go | 24 ++++++++ pkg/sentry/socket/rpcinet/stack.go | 11 ++++ pkg/tcpip/buffer/BUILD | 5 +- pkg/tcpip/buffer/prependable.go | 18 ++++-- pkg/tcpip/buffer/prependable_test.go | 50 +++++++++++++++++ pkg/tcpip/network/ipv4/ipv4.go | 4 +- pkg/tcpip/stack/nic.go | 7 ++- pkg/tcpip/stack/stack.go | 75 ++++++++++++++++++++----- pkg/tcpip/stack/stack_test.go | 5 +- pkg/tcpip/stack/transport_test.go | 2 +- test/syscalls/linux/proc_net.cc | 36 ++++++++++++ 20 files changed, 467 insertions(+), 27 deletions(-) create mode 100644 pkg/tcpip/buffer/prependable_test.go diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index 75cbb0622..f21e2a65c 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -53,6 +53,7 @@ go_library( "//pkg/sentry/usermem", "//pkg/syserror", "//pkg/tcpip/header", + "//pkg/tcpip/network/ipv4", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index f3b63dfc2..794723d9c 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/waiter" ) @@ -280,11 +281,115 @@ func (p *proc) newSysNetCore(ctx context.Context, msrc *fs.MountSource, s inet.S return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil) } +// ipForwarding implements fs.InodeOperations. +// +// ipForwarding is used to enable/disable packet forwarding of netstack. +// +// +stateify savable +type ipForwarding struct { + stack inet.Stack `state:".(ipForwardingState)"` + fsutil.SimpleFileInode +} + +// ipForwardingState is used to stores a state of netstack +// for packet forwarding because netstack itself is stateless. +// +// +stateify savable +type ipForwardingState struct { + stack inet.Stack `state:"wait"` + + // enabled stores packet forwarding settings during save, and sets it back + // in netstack in restore. We must save/restore this here, since + // netstack itself is stateless. + enabled bool +} + +func newIPForwardingInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { + ipf := &ipForwarding{ + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), + stack: s, + } + sattr := fs.StableAttr{ + DeviceID: device.ProcDevice.DeviceID(), + InodeID: device.ProcDevice.NextIno(), + BlockSize: usermem.PageSize, + Type: fs.SpecialFile, + } + return fs.NewInode(ctx, ipf, msrc, sattr) +} + +// +stateify savable +type ipForwardingFile struct { + fsutil.FileGenericSeek `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + waiter.AlwaysReady `state:"nosave"` + + stack inet.Stack `state:"wait"` +} + +// GetFile implements fs.InodeOperations.GetFile. +func (ipf *ipForwarding) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + flags.Pread = true + flags.Pwrite = true + return fs.NewFile(ctx, dirent, flags, &ipForwardingFile{ + stack: ipf.stack, + }), nil +} + +// Read implements fs.FileOperations.Read. +func (f *ipForwardingFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + return 0, io.EOF + } + + val := "0\n" + if f.stack.Forwarding(ipv4.ProtocolNumber) { + // Technically, this is not quite compatible with Linux. Linux + // stores these as an integer, so if you write "2" into + // ip_forward, you should get 2 back. + val = "1\n" + } + + n, err := dst.CopyOut(ctx, []byte(val)) + return int64(n), err +} + +// Write implements fs.FileOperations.Write. +// +// Offset is ignored, multiple writes are not supported. +func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) { + if src.NumBytes() == 0 { + return 0, nil + } + + // Only consider size of one memory page for input. + src = src.TakeFirst(usermem.PageSize - 1) + + var v int32 + n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) + if err != nil { + return n, err + } + + enabled := v != 0 + return n, f.stack.SetForwarding(ipv4.ProtocolNumber, enabled) +} + func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { contents := map[string]*fs.Inode{ // Add tcp_sack. "tcp_sack": newTCPSackInode(ctx, msrc, s), + // Add ip_forward. + "ip_forward": newIPForwardingInode(ctx, msrc, s), + // The following files are simple stubs until they are // implemented in netstack, most of these files are // configuration related. We use the value closest to the diff --git a/pkg/sentry/fs/proc/sys_net_state.go b/pkg/sentry/fs/proc/sys_net_state.go index 6eba709c6..02e43297f 100644 --- a/pkg/sentry/fs/proc/sys_net_state.go +++ b/pkg/sentry/fs/proc/sys_net_state.go @@ -14,7 +14,10 @@ package proc -import "fmt" +import ( + "fmt" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" +) // beforeSave is invoked by stateify. func (t *tcpMemInode) beforeSave() { @@ -40,3 +43,19 @@ func (s *tcpSack) afterLoad() { } } } + +// saveStack is invoked by stateify. +func (ipf *ipForwarding) saveStack() ipForwardingState { + return ipForwardingState{ + ipf.stack, + ipf.stack.Forwarding(ipv4.ProtocolNumber), + } +} + +// loadStack is invoked by stateify. +func (ipf *ipForwarding) loadStack(s ipForwardingState) { + ipf.stack = s.stack + if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, s.enabled); err != nil { + panic(fmt.Sprintf("failed to set previous IPv4 forwarding configuration [%v]: %v", s.enabled, err)) + } +} diff --git a/pkg/sentry/fs/proc/sys_net_test.go b/pkg/sentry/fs/proc/sys_net_test.go index 6abae7a60..6e51dfbb7 100644 --- a/pkg/sentry/fs/proc/sys_net_test.go +++ b/pkg/sentry/fs/proc/sys_net_test.go @@ -123,3 +123,71 @@ func TestConfigureRecvBufferSize(t *testing.T) { } } } + +func TestConfigureIPForwarding(t *testing.T) { + ctx := context.Background() + s := inet.NewTestStack() + + var cases = []struct { + comment string + initial bool + str string + final bool + }{ + { + comment: `Forwarding is disabled; write 1 and enable forwarding`, + initial: false, + str: "1", + final: true, + }, + { + comment: `Forwarding is disabled; write 0 and disable forwarding`, + initial: false, + str: "0", + final: false, + }, + { + comment: `Forwarding is enabled; write 1 and enable forwarding`, + initial: true, + str: "1", + final: true, + }, + { + comment: `Forwarding is enabled; write 0 and disable forwarding`, + initial: true, + str: "0", + final: false, + }, + { + comment: `Forwarding is disabled; write 2404 and enable forwarding`, + initial: false, + str: "2404", + final: true, + }, + { + comment: `Forwarding is enabled; write 2404 and enable forwarding`, + initial: true, + str: "2404", + final: true, + }, + } + for _, c := range cases { + t.Run(c.comment, func(t *testing.T) { + s.IPForwarding = c.initial + + file := &ipForwardingFile{stack: s} + + // Write the values. + src := usermem.BytesIOSequence([]byte(c.str)) + if n, err := file.Write(ctx, nil, src, 0); n != int64(len(c.str)) || err != nil { + t.Errorf("file.Write(ctx, nil, %v, 0) = (%d, %v); wanted (%d, nil)", c.str, n, err, len(c.str)) + } + + // Read the values from the stack and check them. + if s.IPForwarding != c.final { + t.Errorf("s.IPForwarding = %v; wanted %v", s.IPForwarding, c.final) + } + + }) + } +} diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index d5284f0d9..99481e05e 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -13,5 +13,8 @@ go_library( "test_stack.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/inet", - deps = ["//pkg/sentry/context"], + deps = [ + "//pkg/sentry/context", + "//pkg/tcpip", + ], ) diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index bc6cb1095..6217100b2 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -15,6 +15,8 @@ // Package inet defines semantics for IP stacks. package inet +import "gvisor.dev/gvisor/pkg/tcpip" + // Stack represents a TCP/IP stack. type Stack interface { // Interfaces returns all network interfaces as a mapping from interface @@ -58,6 +60,12 @@ type Stack interface { // Resume restarts the network stack after restore. Resume() + + // Forwarding returns if packet forwarding between NICs is enabled. + Forwarding(protocol tcpip.NetworkProtocolNumber) bool + + // SetForwarding enables or disables packet forwarding between NICs. + SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error } // Interface contains information about a network interface. diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index b9eed7c3a..c6907cfcb 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -14,6 +14,10 @@ package inet +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) + // TestStack is a dummy implementation of Stack for tests. type TestStack struct { InterfacesMap map[int32]Interface @@ -23,6 +27,7 @@ type TestStack struct { TCPRecvBufSize TCPBufferSize TCPSendBufSize TCPBufferSize TCPSACKFlag bool + IPForwarding bool } // NewTestStack returns a TestStack with no network interfaces. The value of @@ -96,3 +101,14 @@ func (s *TestStack) RouteTable() []Route { // Resume implements Stack.Resume. func (s *TestStack) Resume() { } + +// Forwarding implements inet.Stack.Forwarding. +func (s *TestStack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + return s.IPForwarding +} + +// SetForwarding implements inet.Stack.SetForwarding. +func (s *TestStack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { + s.IPForwarding = enable + return nil +} diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 4d174dda4..c1b20eaf8 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -32,6 +32,9 @@ go_library( "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", + "//pkg/tcpip", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", "//pkg/waiter", ], ) diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index d4387f5d4..4b460d30e 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -31,6 +31,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" ) var defaultRecvBufSize = inet.TCPBufferSize{ @@ -57,6 +60,8 @@ type Stack struct { tcpSACKEnabled bool netDevFile *os.File netSNMPFile *os.File + ipv4Forwarding bool + ipv6Forwarding bool } // NewStack returns an empty Stack containing no configuration. @@ -116,6 +121,13 @@ func (s *Stack) Configure() error { s.netSNMPFile = f } + s.ipv4Forwarding = false + if ipForwarding, err := ioutil.ReadFile("/proc/sys/net/ipv4/ip_forward"); err == nil { + s.ipv4Forwarding = strings.TrimSpace(string(ipForwarding)) != "0" + } else { + log.Warningf("Failed to read if IPv4 forwarding is enabled, setting to false") + } + return nil } @@ -442,3 +454,21 @@ func (s *Stack) RouteTable() []inet.Route { // Resume implements inet.Stack.Resume. func (s *Stack) Resume() {} + +// Forwarding implements inet.Stack.Forwarding. +func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + switch protocol { + case ipv4.ProtocolNumber: + return s.ipv4Forwarding + case ipv6.ProtocolNumber: + return s.ipv6Forwarding + default: + log.Warningf("Forwarding(%v) failed: unsupported protocol", protocol) + return false + } +} + +// SetForwarding implements inet.Stack.SetForwarding. +func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { + return syserror.EACCES +} diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index d5db8c17c..d0102cfa3 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -291,3 +292,26 @@ func (s *Stack) FillDefaultIPTables() { func (s *Stack) Resume() { s.Stack.Resume() } + +// Forwarding implements inet.Stack.Forwarding. +func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + switch protocol { + case ipv4.ProtocolNumber, ipv6.ProtocolNumber: + return s.Stack.Forwarding(protocol) + default: + log.Warningf("Forwarding(%v) failed: unsupported protocol", protocol) + return false + } +} + +// SetForwarding implements inet.Stack.SetForwarding. +func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { + switch protocol { + case ipv4.ProtocolNumber, ipv6.ProtocolNumber: + s.Stack.SetForwarding(protocol, enable) + default: + log.Warningf("SetForwarding(%v) failed: unsupported protocol", protocol) + return syserr.ErrProtocolNotSupported.ToError() + } + return nil +} diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go index 5dcb6b455..f5441b826 100644 --- a/pkg/sentry/socket/rpcinet/stack.go +++ b/pkg/sentry/socket/rpcinet/stack.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn" "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/unet" ) @@ -165,3 +166,13 @@ func (s *Stack) RouteTable() []inet.Route { // Resume implements inet.Stack.Resume. func (s *Stack) Resume() {} + +// Forwarding implements inet.Stack.Forwarding. +func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + panic("rpcinet handles procfs directly this method should not be called") +} + +// SetForwarding implements inet.Stack.SetForwarding. +func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { + panic("rpcinet handles procfs directly this method should not be called") +} diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index d6c31bfa2..a7bf0c4dc 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -16,6 +16,9 @@ go_library( go_test( name = "buffer_test", size = "small", - srcs = ["view_test.go"], + srcs = [ + "prependable_test.go", + "view_test.go", + ], embed = [":buffer"], ) diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go index 48a2a2713..2f9a23d61 100644 --- a/pkg/tcpip/buffer/prependable.go +++ b/pkg/tcpip/buffer/prependable.go @@ -32,13 +32,19 @@ func NewPrependable(size int) Prependable { return Prependable{buf: NewView(size), usedIdx: size} } -// NewPrependableFromView creates an entirely-used Prependable from a View. +// NewPrependableFromView creates a Prependable from a View and allocates +// additional space if needed. // -// NewPrependableFromView takes ownership of v. Note that since the entire -// prependable is used, further attempts to call Prepend will note that size > -// p.usedIdx and return nil. -func NewPrependableFromView(v View) Prependable { - return Prependable{buf: v, usedIdx: 0} +// NewPrependableFromView takes ownership of v. Note that if the entire +// prependable is used, further attempts to call Prepend will note that +// size > p.usedIdx and return nil. +func NewPrependableFromView(v View, extraCap int) Prependable { + if extraCap == 0 { + return Prependable{buf: v, usedIdx: 0} + } + buf := make([]byte, extraCap, extraCap + len(v)) + buf = append(buf, v...) + return Prependable{buf: buf, usedIdx: extraCap} } // NewEmptyPrependableFromView creates a new prependable buffer from a View. diff --git a/pkg/tcpip/buffer/prependable_test.go b/pkg/tcpip/buffer/prependable_test.go new file mode 100644 index 000000000..43660c307 --- /dev/null +++ b/pkg/tcpip/buffer/prependable_test.go @@ -0,0 +1,50 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package buffer + +import ( + "reflect" + "testing" +) + +func TestNewPrependableFromView(t *testing.T) { + tests := []struct { + comment string + view View + extraSize int + want Prependable + }{ + { + comment: "Reserve extra space", + view: View("abc"), + extraSize: 2, + want: Prependable{buf: View("\x00\x00abc"), usedIdx: 2}, + }, + { + comment: "Don't reserve extra space", + view: View("abc"), + extraSize: 0, + want: Prependable{buf: View("abc"), usedIdx: 0}, + }, + } + + for _, testCase := range tests { + t.Run(testCase.comment, func(t *testing.T) { + prep := NewPrependableFromView(testCase.view, testCase.extraSize) + if !reflect.DeepEqual(prep, testCase.want) { + t.Errorf("NewPrependableFromView(%#v, %d) = %#v; want %#v", testCase.view, testCase.extraSize, prep, testCase.want) + } + } ) + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 1339f8474..90f4406e5 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -307,7 +307,9 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect return nil } - hdr := buffer.NewPrependableFromView(payload.ToView()) + // If we want to send the packet to a link-layer, + // we have to reserve space for an Ethernet header. + hdr := buffer.NewPrependableFromView(payload.ToView(), int(e.linkEP.MaxHeaderLength())) r.Stats().IP.PacketsSent.Increment() return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index a867f8c00..ab6798aa6 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -780,7 +780,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // packet and forward it to the NIC. // // TODO: Should we be forwarding the packet even if promiscuous? - if n.stack.Forwarding() { + if n.stack.Forwarding(protocol) { r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) if err != nil { n.stack.stats.IP.InvalidAddressesReceived.Increment() @@ -805,9 +805,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } else { // n doesn't have a destination endpoint. // Send the packet out of n. - hdr := buffer.NewPrependableFromView(vv.First()) + // If we want to send the packet to a link-layer, + // we have to reserve space for an Ethernet header. + hdr := buffer.NewPrependableFromView(vv.First(), int(n.linkEP.MaxHeaderLength())) vv.RemoveFirst() + // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. // TODO(b/128629022): use route.WritePacket. if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, vv, protocol); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 242d2150c..71e0618f4 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -21,7 +21,9 @@ package stack import ( "encoding/binary" + "math" "sync" + "sync/atomic" "time" "golang.org/x/time/rate" @@ -48,6 +50,42 @@ const ( DefaultTOS = 0 ) +const ( + // fakeNetNumber is used as a protocol number in tests. + // + // This constant should match fakeNetNumber in stack_test.go. + fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 +) + +type forwardingFlag uint32 + +// Packet forwarding flags. Forwarding settings for different network protocols +// are stored as bit flags in an uint32 number. +const ( + forwardingIPv4 forwardingFlag = 1 << iota + forwardingIPv6 + + // forwardingFake is used to test package forwarding with a fake protocol. + forwardingFake +) + +func getForwardingFlag(protocol tcpip.NetworkProtocolNumber) forwardingFlag { + var flag forwardingFlag + switch protocol { + case header.IPv4ProtocolNumber: + flag = forwardingIPv4 + case header.IPv6ProtocolNumber: + flag = forwardingIPv6 + case fakeNetNumber: + // This network protocol number is used in stack_test to test + // packet forwarding. + flag = forwardingFake + default: + // We only support forwarding for IPv4 and IPv6. + } + return flag +} + type transportProtocolState struct { proto TransportProtocol defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool @@ -363,7 +401,10 @@ type Stack struct { mu sync.RWMutex nics map[tcpip.NICID]*NIC - forwarding bool + + // forwarding contains the enable bits for packet forwarding for different + // network protocols. + forwarding uint32 // route is the route table passed in by the user via SetRouteTable(), // it is used by FindRoute() to build a route for a specific @@ -630,20 +671,28 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwarding enables or disables the packet forwarding between NICs. -func (s *Stack) SetForwarding(enable bool) { - // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. - s.mu.Lock() - s.forwarding = enable - s.mu.Unlock() +// SetForwarding enables or disables packet forwarding between NICs. +func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) { + flag := getForwardingFlag(protocol) + for { + forwarding := forwardingFlag(atomic.LoadUint32(&s.forwarding)) + var newValue forwardingFlag + if enable { + newValue = forwarding | flag + } else { + newValue = forwarding & ^flag + } + if atomic.CompareAndSwapUint32(&s.forwarding, uint32(forwarding), uint32(newValue)) { + break + } + } } -// Forwarding returns if the packet forwarding between NICs is enabled. -func (s *Stack) Forwarding() bool { - // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. - s.mu.RLock() - defer s.mu.RUnlock() - return s.forwarding +// Forwarding returns if packet forwarding between NICs is enabled. +func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + flag := getForwardingFlag(protocol) + forwarding := forwardingFlag(atomic.LoadUint32(&s.forwarding)) + return forwarding & flag != 0 } // SetRouteTable assigns the route table to be used by this stack. It diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 9dae853d0..ef3d1beb0 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -36,6 +36,9 @@ import ( ) const ( + // fakeNetNumber is used as a protocol number in tests. + // + // This constant should match fakeNetNumber in stack.go. fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 fakeNetHeaderLen = 12 fakeDefaultPrefixLen = 8 @@ -1825,7 +1828,7 @@ func TestNICForwarding(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) - s.SetForwarding(true) + s.SetForwarding(fakeNetNumber, true) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep1); err != nil { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 86c62be25..6d3daed24 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -528,7 +528,7 @@ func TestTransportForwarding(t *testing.T) { NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, }) - s.SetForwarding(true) + s.SetForwarding(fakeNetNumber, true) // TODO(b/123449044): Change this to a channel NIC. ep1 := loopback.New() diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index 65bad06d4..897cf4950 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -326,6 +326,42 @@ TEST(ProcNetSnmp, UdpIn) { EXPECT_EQ(oldInDatagrams, newInDatagrams - 1); } +TEST(ProcSysNetIpv4IpForward, Exists) { + auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/sys/net/ipv4/ip_forward", O_RDWR)); +} + +TEST(ProcSysNetIpv4IpForward, DefaultValueEqZero) { + auto const fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/sys/net/ipv4/ip_forward", O_RDWR)); + + char buf = 101; + EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_TRUE(buf == '0') << "unexpected ip_forward: " << buf; +} + +TEST(ProcSysNetIpv4IpForward, CanReadAndWrite) { + auto const fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/sys/net/ipv4/ip_forward", O_RDWR)); + + char buf = 101; + EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_TRUE(buf == '0') << "unexpected ip_forward: " << buf; + + constexpr char to_write = '1'; + EXPECT_THAT(PwriteFd(fd.get(), &to_write, sizeof(to_write), 0), + SyscallSucceedsWithValue(sizeof(to_write))); + + buf = 101; + EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); + EXPECT_EQ(buf, to_write); +} + } // namespace } // namespace testing } // namespace gvisor -- cgit v1.2.3