summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/nic.go7
-rw-r--r--pkg/tcpip/stack/stack.go75
-rw-r--r--pkg/tcpip/stack/stack_test.go5
-rw-r--r--pkg/tcpip/stack/transport_test.go2
4 files changed, 72 insertions, 17 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index a867f8c00..ab6798aa6 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -780,7 +780,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
// packet and forward it to the NIC.
//
// TODO: Should we be forwarding the packet even if promiscuous?
- if n.stack.Forwarding() {
+ if n.stack.Forwarding(protocol) {
r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */)
if err != nil {
n.stack.stats.IP.InvalidAddressesReceived.Increment()
@@ -805,9 +805,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
} else {
// n doesn't have a destination endpoint.
// Send the packet out of n.
- hdr := buffer.NewPrependableFromView(vv.First())
+ // If we want to send the packet to a link-layer,
+ // we have to reserve space for an Ethernet header.
+ hdr := buffer.NewPrependableFromView(vv.First(), int(n.linkEP.MaxHeaderLength()))
vv.RemoveFirst()
+ // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6.
// TODO(b/128629022): use route.WritePacket.
if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, vv, protocol); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 242d2150c..71e0618f4 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -21,7 +21,9 @@ package stack
import (
"encoding/binary"
+ "math"
"sync"
+ "sync/atomic"
"time"
"golang.org/x/time/rate"
@@ -48,6 +50,42 @@ const (
DefaultTOS = 0
)
+const (
+ // fakeNetNumber is used as a protocol number in tests.
+ //
+ // This constant should match fakeNetNumber in stack_test.go.
+ fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
+)
+
+type forwardingFlag uint32
+
+// Packet forwarding flags. Forwarding settings for different network protocols
+// are stored as bit flags in an uint32 number.
+const (
+ forwardingIPv4 forwardingFlag = 1 << iota
+ forwardingIPv6
+
+ // forwardingFake is used to test package forwarding with a fake protocol.
+ forwardingFake
+)
+
+func getForwardingFlag(protocol tcpip.NetworkProtocolNumber) forwardingFlag {
+ var flag forwardingFlag
+ switch protocol {
+ case header.IPv4ProtocolNumber:
+ flag = forwardingIPv4
+ case header.IPv6ProtocolNumber:
+ flag = forwardingIPv6
+ case fakeNetNumber:
+ // This network protocol number is used in stack_test to test
+ // packet forwarding.
+ flag = forwardingFake
+ default:
+ // We only support forwarding for IPv4 and IPv6.
+ }
+ return flag
+}
+
type transportProtocolState struct {
proto TransportProtocol
defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool
@@ -363,7 +401,10 @@ type Stack struct {
mu sync.RWMutex
nics map[tcpip.NICID]*NIC
- forwarding bool
+
+ // forwarding contains the enable bits for packet forwarding for different
+ // network protocols.
+ forwarding uint32
// route is the route table passed in by the user via SetRouteTable(),
// it is used by FindRoute() to build a route for a specific
@@ -630,20 +671,28 @@ func (s *Stack) Stats() tcpip.Stats {
return s.stats
}
-// SetForwarding enables or disables the packet forwarding between NICs.
-func (s *Stack) SetForwarding(enable bool) {
- // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
- s.mu.Lock()
- s.forwarding = enable
- s.mu.Unlock()
+// SetForwarding enables or disables packet forwarding between NICs.
+func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) {
+ flag := getForwardingFlag(protocol)
+ for {
+ forwarding := forwardingFlag(atomic.LoadUint32(&s.forwarding))
+ var newValue forwardingFlag
+ if enable {
+ newValue = forwarding | flag
+ } else {
+ newValue = forwarding & ^flag
+ }
+ if atomic.CompareAndSwapUint32(&s.forwarding, uint32(forwarding), uint32(newValue)) {
+ break
+ }
+ }
}
-// Forwarding returns if the packet forwarding between NICs is enabled.
-func (s *Stack) Forwarding() bool {
- // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
- s.mu.RLock()
- defer s.mu.RUnlock()
- return s.forwarding
+// Forwarding returns if packet forwarding between NICs is enabled.
+func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
+ flag := getForwardingFlag(protocol)
+ forwarding := forwardingFlag(atomic.LoadUint32(&s.forwarding))
+ return forwarding & flag != 0
}
// SetRouteTable assigns the route table to be used by this stack. It
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 9dae853d0..ef3d1beb0 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -36,6 +36,9 @@ import (
)
const (
+ // fakeNetNumber is used as a protocol number in tests.
+ //
+ // This constant should match fakeNetNumber in stack.go.
fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
fakeNetHeaderLen = 12
fakeDefaultPrefixLen = 8
@@ -1825,7 +1828,7 @@ func TestNICForwarding(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
- s.SetForwarding(true)
+ s.SetForwarding(fakeNetNumber, true)
ep1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep1); err != nil {
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 86c62be25..6d3daed24 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -528,7 +528,7 @@ func TestTransportForwarding(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
})
- s.SetForwarding(true)
+ s.SetForwarding(fakeNetNumber, true)
// TODO(b/123449044): Change this to a channel NIC.
ep1 := loopback.New()