diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/forwarder_test.go | 34 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 106 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 2 |
7 files changed, 108 insertions, 58 deletions
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 5a684eb9d..9dff23623 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -16,7 +16,6 @@ package stack import ( "encoding/binary" - "math" "testing" "time" @@ -26,9 +25,8 @@ import ( ) const ( - fwdTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 - fwdTestNetHeaderLen = 12 - fwdTestNetDefaultPrefixLen = 8 + fwdTestNetHeaderLen = 12 + fwdTestNetDefaultPrefixLen = 8 // fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, // except where another value is explicitly used. It is chosen to match @@ -92,7 +90,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH b[srcAddrOffset] = r.LocalAddress[0] b[protocolNumberOffset] = byte(params.Protocol) - return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt) + return f.ep.WritePacket(r, gso, fakeNetNumber, pkt) } // WritePackets implements LinkEndpoint.WritePackets. @@ -118,7 +116,7 @@ type fwdTestNetworkProtocol struct { var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { - return fwdTestNetNumber + return fakeNetNumber } func (f *fwdTestNetworkProtocol) MinimumPacketSize() int { @@ -179,7 +177,7 @@ func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip } func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return fwdTestNetNumber + return fakeNetNumber } // fwdTestPacketInfo holds all the information about an outbound packet. @@ -309,7 +307,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f proto.addrCache = s.linkAddrCache // Enable forwarding. - s.SetForwarding(true) + s.SetForwarding(proto.Number(), true) // NIC 1 has the link address "a", and added the network address 1. ep1 = &fwdTestLinkEndpoint{ @@ -320,7 +318,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC #1 failed:", err) } - if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil { + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatal("AddAddress #1 failed:", err) } @@ -333,7 +331,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC #2 failed:", err) } - if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil { + if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { t.Fatal("AddAddress #2 failed:", err) } @@ -368,7 +366,7 @@ func TestForwardingWithStaticResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), })) @@ -405,7 +403,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), })) @@ -436,7 +434,7 @@ func TestForwardingWithNoResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), })) @@ -466,7 +464,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // not be forwarded. buf := buffer.NewView(30) buf[dstAddrOffset] = 4 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), })) @@ -474,7 +472,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // forwarded to NIC 2. buf = buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), })) @@ -515,7 +513,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { for i := 0; i < 2; i++ { buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), })) } @@ -561,7 +559,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { buf[dstAddrOffset] = 3 // Set the packet sequence number. binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), })) } @@ -619,7 +617,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // maxPendingResolutions + 7). buf := buffer.NewView(30) buf[dstAddrOffset] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), })) } diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index b0873d1af..97ca00d16 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -817,7 +817,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // per-interface basis; it is a stack-wide configuration, so we check // stack's forwarding flag to determine if the NIC is a routing // interface. - if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding { + if !ndp.configs.HandleRAs || ndp.nic.stack.Forwarding(header.IPv6ProtocolNumber) { return } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 21bf53010..1a6724c31 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1120,7 +1120,7 @@ func TestNoRouterDiscovery(t *testing.T) { }, NDPDisp: &ndpDisp, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -1365,7 +1365,7 @@ func TestNoPrefixDiscovery(t *testing.T) { }, NDPDisp: &ndpDisp, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -1723,7 +1723,7 @@ func TestNoAutoGenAddr(t *testing.T) { }, NDPDisp: &ndpDisp, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -4580,7 +4580,7 @@ func TestCleanupNDPState(t *testing.T) { name: "Enable forwarding", cleanupFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(true) + s.SetForwarding(ipv6.ProtocolNumber, true) }, keepAutoGenLinkLocal: true, maxAutoGenAddrEvents: 4, @@ -5226,11 +5226,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { name: "Enable and disable forwarding", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(false) + s.SetForwarding(ipv6.ProtocolNumber, false) }, stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() - s.SetForwarding(true) + s.SetForwarding(ipv6.ProtocolNumber, true) }, }, diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 728292782..e74d2562a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -327,7 +327,7 @@ func (n *NIC) enable() *tcpip.Error { // does. That is, routers do not learn from RAs (e.g. on-link prefixes // and default routers). Therefore, soliciting RAs from other routers on // a link is unnecessary for routers. - if !n.stack.forwarding { + if !n.stack.Forwarding(header.IPv6ProtocolNumber) { n.mu.ndp.startSolicitingRouters() } @@ -1256,7 +1256,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // packet and forward it to the NIC. // // TODO: Should we be forwarding the packet even if promiscuous? - if n.stack.Forwarding() { + if n.stack.Forwarding(protocol) { r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) if err != nil { n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() @@ -1283,6 +1283,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // n doesn't have a destination endpoint. // Send the packet out of n. // TODO(b/128629022): move this logic to route.WritePacket. + // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. if ch, err := r.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a3f87c8af..814b3e94a 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -22,6 +22,7 @@ package stack import ( "bytes" "encoding/binary" + "math" mathrand "math/rand" "sync/atomic" "time" @@ -50,6 +51,41 @@ const ( DefaultTOS = 0 ) +const ( + // fakeNetNumber is used as a protocol number in tests. + // + // This constant should match fakeNetNumber in stack_test.go. + fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 +) + +type forwardingFlag uint32 + +// Packet forwarding flags. Forwarding settings for different network protocols +// are stored as bit flags in an uint32 number. +const ( + forwardingIPv4 forwardingFlag = 1 << iota + forwardingIPv6 + + // forwardingFake is used to test package forwarding with a fake protocol. + forwardingFake +) + +func getForwardingFlag(protocol tcpip.NetworkProtocolNumber) forwardingFlag { + var flag forwardingFlag + switch protocol { + case header.IPv4ProtocolNumber: + flag = forwardingIPv4 + case header.IPv6ProtocolNumber: + flag = forwardingIPv6 + case fakeNetNumber: + // This network protocol number is used to test packet forwarding. + flag = forwardingFake + default: + // We only support forwarding for IPv4 and IPv6. + } + return flag +} + type transportProtocolState struct { proto TransportProtocol defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool @@ -415,9 +451,16 @@ type Stack struct { linkAddrCache *linkAddrCache - mu sync.RWMutex - nics map[tcpip.NICID]*NIC - forwarding bool + mu sync.RWMutex + nics map[tcpip.NICID]*NIC + + // forwarding contains the enable bits for packet forwarding for different + // network protocols. + forwarding struct { + sync.RWMutex + flag forwardingFlag + } + cleanupEndpoints map[TransportEndpoint]struct{} // route is the route table passed in by the user via SetRouteTable(), @@ -851,46 +894,51 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwarding enables or disables the packet forwarding between NICs. -// -// When forwarding becomes enabled, any host-only state on all NICs will be -// cleaned up and if IPv6 is enabled, NDP Router Solicitations will be started. -// When forwarding becomes disabled and if IPv6 is enabled, NDP Router -// Solicitations will be stopped. -func (s *Stack) SetForwarding(enable bool) { - // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. - s.mu.Lock() - defer s.mu.Unlock() +// SetForwarding enables or disables packet forwarding between NICs. +func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) { + s.forwarding.Lock() + defer s.forwarding.Unlock() - // If forwarding status didn't change, do nothing further. - if s.forwarding == enable { + // If this stack does not support the protocol, do nothing. + if _, ok := s.networkProtocols[protocol]; !ok { return } - s.forwarding = enable + flag := getForwardingFlag(protocol) - // If this stack does not support IPv6, do nothing further. - if _, ok := s.networkProtocols[header.IPv6ProtocolNumber]; !ok { + // If the forwarding value for this protocol hasn't changed then do + // nothing. + if s.forwarding.flag&getForwardingFlag(protocol) != 0 == enable { return } + var newValue forwardingFlag if enable { - for _, nic := range s.nics { - nic.becomeIPv6Router() - } + newValue = s.forwarding.flag | flag } else { - for _, nic := range s.nics { - nic.becomeIPv6Host() + newValue = s.forwarding.flag & ^flag + } + s.forwarding.flag = newValue + + // Enable or disable NDP for IPv6. + if protocol == header.IPv6ProtocolNumber { + if enable { + for _, nic := range s.nics { + nic.becomeIPv6Router() + } + } else { + for _, nic := range s.nics { + nic.becomeIPv6Host() + } } } } -// Forwarding returns if the packet forwarding between NICs is enabled. -func (s *Stack) Forwarding() bool { - // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. - s.mu.RLock() - defer s.mu.RUnlock() - return s.forwarding +// Forwarding returns if packet forwarding between NICs is enabled. +func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + s.forwarding.RLock() + defer s.forwarding.RUnlock() + return s.forwarding.flag&getForwardingFlag(protocol) != 0 } // SetRouteTable assigns the route table to be used by this stack. It diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 106645c50..f168be402 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -42,6 +42,9 @@ import ( ) const ( + // fakeNetNumber is used as a protocol number in tests. + // + // This constant should match fakeNetNumber in stack.go. fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 fakeNetHeaderLen = 12 fakeDefaultPrefixLen = 8 @@ -2125,7 +2128,7 @@ func TestNICForwarding(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) - s.SetForwarding(true) + s.SetForwarding(fakeNetNumber, true) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID1, ep1); err != nil { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 6c6e44468..fa4b14ba6 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -580,7 +580,7 @@ func TestTransportForwarding(t *testing.T) { NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, }) - s.SetForwarding(true) + s.SetForwarding(fakeNetNumber, true) // TODO(b/123449044): Change this to a channel NIC. ep1 := loopback.New() |