diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/nic.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 75 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 2 |
4 files changed, 72 insertions, 17 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index a867f8c00..ab6798aa6 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -780,7 +780,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // packet and forward it to the NIC. // // TODO: Should we be forwarding the packet even if promiscuous? - if n.stack.Forwarding() { + if n.stack.Forwarding(protocol) { r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) if err != nil { n.stack.stats.IP.InvalidAddressesReceived.Increment() @@ -805,9 +805,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } else { // n doesn't have a destination endpoint. // Send the packet out of n. - hdr := buffer.NewPrependableFromView(vv.First()) + // If we want to send the packet to a link-layer, + // we have to reserve space for an Ethernet header. + hdr := buffer.NewPrependableFromView(vv.First(), int(n.linkEP.MaxHeaderLength())) vv.RemoveFirst() + // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. // TODO(b/128629022): use route.WritePacket. if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, vv, protocol); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 242d2150c..71e0618f4 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -21,7 +21,9 @@ package stack import ( "encoding/binary" + "math" "sync" + "sync/atomic" "time" "golang.org/x/time/rate" @@ -48,6 +50,42 @@ const ( DefaultTOS = 0 ) +const ( + // fakeNetNumber is used as a protocol number in tests. + // + // This constant should match fakeNetNumber in stack_test.go. + fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 +) + +type forwardingFlag uint32 + +// Packet forwarding flags. Forwarding settings for different network protocols +// are stored as bit flags in an uint32 number. +const ( + forwardingIPv4 forwardingFlag = 1 << iota + forwardingIPv6 + + // forwardingFake is used to test package forwarding with a fake protocol. + forwardingFake +) + +func getForwardingFlag(protocol tcpip.NetworkProtocolNumber) forwardingFlag { + var flag forwardingFlag + switch protocol { + case header.IPv4ProtocolNumber: + flag = forwardingIPv4 + case header.IPv6ProtocolNumber: + flag = forwardingIPv6 + case fakeNetNumber: + // This network protocol number is used in stack_test to test + // packet forwarding. + flag = forwardingFake + default: + // We only support forwarding for IPv4 and IPv6. + } + return flag +} + type transportProtocolState struct { proto TransportProtocol defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool @@ -363,7 +401,10 @@ type Stack struct { mu sync.RWMutex nics map[tcpip.NICID]*NIC - forwarding bool + + // forwarding contains the enable bits for packet forwarding for different + // network protocols. + forwarding uint32 // route is the route table passed in by the user via SetRouteTable(), // it is used by FindRoute() to build a route for a specific @@ -630,20 +671,28 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwarding enables or disables the packet forwarding between NICs. -func (s *Stack) SetForwarding(enable bool) { - // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. - s.mu.Lock() - s.forwarding = enable - s.mu.Unlock() +// SetForwarding enables or disables packet forwarding between NICs. +func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) { + flag := getForwardingFlag(protocol) + for { + forwarding := forwardingFlag(atomic.LoadUint32(&s.forwarding)) + var newValue forwardingFlag + if enable { + newValue = forwarding | flag + } else { + newValue = forwarding & ^flag + } + if atomic.CompareAndSwapUint32(&s.forwarding, uint32(forwarding), uint32(newValue)) { + break + } + } } -// Forwarding returns if the packet forwarding between NICs is enabled. -func (s *Stack) Forwarding() bool { - // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. - s.mu.RLock() - defer s.mu.RUnlock() - return s.forwarding +// Forwarding returns if packet forwarding between NICs is enabled. +func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + flag := getForwardingFlag(protocol) + forwarding := forwardingFlag(atomic.LoadUint32(&s.forwarding)) + return forwarding & flag != 0 } // SetRouteTable assigns the route table to be used by this stack. It diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 9dae853d0..ef3d1beb0 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -36,6 +36,9 @@ import ( ) const ( + // fakeNetNumber is used as a protocol number in tests. + // + // This constant should match fakeNetNumber in stack.go. fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 fakeNetHeaderLen = 12 fakeDefaultPrefixLen = 8 @@ -1825,7 +1828,7 @@ func TestNICForwarding(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) - s.SetForwarding(true) + s.SetForwarding(fakeNetNumber, true) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep1); err != nil { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 86c62be25..6d3daed24 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -528,7 +528,7 @@ func TestTransportForwarding(t *testing.T) { NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, }) - s.SetForwarding(true) + s.SetForwarding(fakeNetNumber, true) // TODO(b/123449044): Change this to a channel NIC. ep1 := loopback.New() |