diff options
Diffstat (limited to 'src/ratelimiter.c')
-rw-r--r-- | src/ratelimiter.c | 249 |
1 files changed, 153 insertions, 96 deletions
diff --git a/src/ratelimiter.c b/src/ratelimiter.c index ab8f93d..2d2e758 100644 --- a/src/ratelimiter.c +++ b/src/ratelimiter.c @@ -1,138 +1,195 @@ /* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */ #include "ratelimiter.h" -#include "peer.h" -#include "device.h" - -#include <linux/module.h> -#include <linux/netfilter/x_tables.h> +#include <linux/siphash.h> +#include <linux/vmalloc.h> +#include <linux/slab.h> +#include <linux/hashtable.h> #include <net/ip.h> -static struct xt_match *v4_match __read_mostly; +static struct kmem_cache *entry_cache; +static hsiphash_key_t key; +static spinlock_t table_lock = __SPIN_LOCK_UNLOCKED("ratelimiter_table_lock"); +static atomic64_t refcnt = ATOMIC64_INIT(0); +static atomic_t total_entries = ATOMIC_INIT(0); +static unsigned int max_entries, table_size; +static void gc_entries(struct work_struct *); +static DECLARE_DEFERRABLE_WORK(gc_work, gc_entries); +static struct hlist_head *table_v4; #if IS_ENABLED(CONFIG_IPV6) -static struct xt_match *v6_match __read_mostly; +static struct hlist_head *table_v6; #endif +struct entry { + u64 last_time_ns, tokens; + void *net; + __be32 ip[3]; + spinlock_t lock; + struct hlist_node hash; + struct rcu_head rcu; +}; + enum { - RATELIMITER_PACKETS_PER_SECOND = 30, - RATELIMITER_PACKETS_BURSTABLE = 5 + PACKETS_PER_SECOND = 20, + PACKETS_BURSTABLE = 5, + PACKET_COST = NSEC_PER_SEC / PACKETS_PER_SECOND, + TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE }; -static inline void cfg_init(struct hashlimit_cfg1 *cfg, int family) +static void entry_free(struct rcu_head *rcu) { - memset(cfg, 0, sizeof(struct hashlimit_cfg1)); - if (family == NFPROTO_IPV4) - cfg->srcmask = 32; - else if (family == NFPROTO_IPV6) - cfg->srcmask = 96; - cfg->mode = XT_HASHLIMIT_HASH_SIP; /* source IP only -- we could also do source port by ORing this with XT_HASHLIMIT_HASH_SPT, but we don't really want to do that. It would also cause problems since we skb_pull early on, and hashlimit's nexthdr stuff isn't so nice. */ - cfg->avg = XT_HASHLIMIT_SCALE / RATELIMITER_PACKETS_PER_SECOND; /* 30 per second per IP */ - cfg->burst = RATELIMITER_PACKETS_BURSTABLE; /* Allow bursts of 5 at a time */ - cfg->gc_interval = 1000; /* same as expiration date */ - cfg->expire = 1000; /* Units of avg (seconds = 1) times 1000 */ - /* cfg->size and cfg->max are computed based on the memory size of left to zero */ + kmem_cache_free(entry_cache, container_of(rcu, struct entry, rcu)); + atomic_dec(&total_entries); } -int ratelimiter_init(struct ratelimiter *ratelimiter, struct wireguard_device *wg) +static void entry_uninit(struct entry *entry) { - struct net_device *dev = netdev_pub(wg); - struct xt_mtchk_param chk = { .net = wg->creating_net }; - int ret; - - memset(ratelimiter, 0, sizeof(struct ratelimiter)); - - cfg_init(&ratelimiter->v4_info.cfg, NFPROTO_IPV4); - memcpy(ratelimiter->v4_info.name, dev->name, IFNAMSIZ); - chk.matchinfo = &ratelimiter->v4_info; - chk.match = v4_match; - chk.family = NFPROTO_IPV4; - ret = v4_match->checkentry(&chk); - if (ret < 0) - return ret; - -#if IS_ENABLED(CONFIG_IPV6) - cfg_init(&ratelimiter->v6_info.cfg, NFPROTO_IPV6); - memcpy(ratelimiter->v6_info.name, dev->name, IFNAMSIZ); - chk.matchinfo = &ratelimiter->v6_info; - chk.match = v6_match; - chk.family = NFPROTO_IPV6; - ret = v6_match->checkentry(&chk); - if (ret < 0) { - struct xt_mtdtor_param dtor_v4 = { - .net = wg->creating_net, - .match = v4_match, - .matchinfo = &ratelimiter->v4_info, - .family = NFPROTO_IPV4 - }; - v4_match->destroy(&dtor_v4); - return ret; - } -#endif - - ratelimiter->net = wg->creating_net; - return 0; + hlist_del_rcu(&entry->hash); + call_rcu_bh(&entry->rcu, entry_free); } -void ratelimiter_uninit(struct ratelimiter *ratelimiter) +/* Calling this function with a NULL work uninits all entries. */ +static void gc_entries(struct work_struct *work) { - struct xt_mtdtor_param dtor = { .net = ratelimiter->net }; - - dtor.match = v4_match; - dtor.matchinfo = &ratelimiter->v4_info; - dtor.family = NFPROTO_IPV4; - v4_match->destroy(&dtor); - + unsigned int i; + struct entry *entry; + struct hlist_node *temp; + const u64 now = ktime_get_ns(); + + for (i = 0; i < table_size; ++i) { + spin_lock(&table_lock); + hlist_for_each_entry_safe (entry, temp, &table_v4[i], hash) { + if (unlikely(!work) || now - entry->last_time_ns > NSEC_PER_SEC) + entry_uninit(entry); + } #if IS_ENABLED(CONFIG_IPV6) - dtor.match = v6_match; - dtor.matchinfo = &ratelimiter->v6_info; - dtor.family = NFPROTO_IPV6; - v6_match->destroy(&dtor); + hlist_for_each_entry_safe (entry, temp, &table_v6[i], hash) { + if (unlikely(!work) || now - entry->last_time_ns > NSEC_PER_SEC) + entry_uninit(entry); + } #endif + spin_unlock(&table_lock); + if (likely(work)) + cond_resched(); + } + if (likely(work)) + queue_delayed_work(system_power_efficient_wq, &gc_work, HZ); } -bool ratelimiter_allow(struct ratelimiter *ratelimiter, struct sk_buff *skb) +bool ratelimiter_allow(struct sk_buff *skb, struct net *net) { - struct xt_action_param action = { { NULL } }; - if (unlikely(skb->len < sizeof(struct iphdr))) - return false; - if (ip_hdr(skb)->version == 4) { - action.match = v4_match; - action.matchinfo = &ratelimiter->v4_info; - action.thoff = ip_hdrlen(skb); + struct entry *entry; + struct hlist_head *bucket; + struct { u32 net; __be32 ip[3]; } data = { .net = (unsigned long)net & 0xffffffff }; + + if (skb->len >= sizeof(struct iphdr) && ip_hdr(skb)->version == 4) { + data.ip[0] = ip_hdr(skb)->saddr; + bucket = &table_v4[hsiphash(&data, sizeof(u32) * 2, &key) & (table_size - 1)]; } #if IS_ENABLED(CONFIG_IPV6) - else if (ip_hdr(skb)->version == 6) { - action.match = v6_match; - action.matchinfo = &ratelimiter->v6_info; + else if (skb->len >= sizeof(struct ipv6hdr) && ip_hdr(skb)->version == 6) { + memcpy(data.ip, &ipv6_hdr(skb)->saddr, sizeof(u32) * 3); /* Only 96 bits */ + bucket = &table_v6[hsiphash(&data, sizeof(u32) * 4, &key) & (table_size - 1)]; } #endif else return false; - return action.match->match(skb, &action); + rcu_read_lock(); + hlist_for_each_entry_rcu (entry, bucket, hash) { + if (entry->net == net && !memcmp(entry->ip, data.ip, sizeof(data.ip))) { + u64 now, tokens; + bool ret; + /* Inspired by nft_limit.c, but this is actually a slightly different + * algorithm. Namely, we incorporate the burst as part of the maximum + * tokens, rather than as part of the rate. */ + spin_lock(&entry->lock); + now = ktime_get_ns(); + tokens = min_t(u64, TOKEN_MAX, entry->tokens + now - entry->last_time_ns); + entry->last_time_ns = now; + ret = tokens >= PACKET_COST; + entry->tokens = ret ? tokens - PACKET_COST : tokens; + spin_unlock(&entry->lock); + rcu_read_unlock(); + return ret; + } + } + rcu_read_unlock(); + + if (atomic_inc_return(&total_entries) > max_entries) + goto err_oom; + + entry = kmem_cache_alloc(entry_cache, GFP_KERNEL); + if (!entry) + goto err_oom; + + entry->net = net; + memcpy(entry->ip, data.ip, sizeof(data.ip)); + INIT_HLIST_NODE(&entry->hash); + spin_lock_init(&entry->lock); + entry->last_time_ns = ktime_get_ns(); + entry->tokens = TOKEN_MAX - PACKET_COST; + spin_lock(&table_lock); + hlist_add_head_rcu(&entry->hash, bucket); + spin_unlock(&table_lock); + return true; + +err_oom: + atomic_dec(&total_entries); + return false; } -int ratelimiter_module_init(void) +int ratelimiter_init(void) { - v4_match = xt_request_find_match(NFPROTO_IPV4, "hashlimit", 1); - if (IS_ERR(v4_match)) { - pr_err("The xt_hashlimit module for IPv4 is required\n"); - return PTR_ERR(v4_match); - } + if (atomic64_inc_return(&refcnt) != 1) + return 0; + + entry_cache = kmem_cache_create("wireguard_ratelimiter", sizeof(struct entry), 0, 0, NULL); + if (!entry_cache) + goto err; + + /* xt_hashlimit.c uses a slightly different algorithm for ratelimiting, + * but what it shares in common is that it uses a massive hashtable. So, + * we borrow their wisdom about good table sizes on different systems + * dependent on RAM. This calculation here comes from there. */ + table_size = (totalram_pages > (1 << 30) / PAGE_SIZE) ? 8192 : max_t(unsigned long, 16, roundup_pow_of_two((totalram_pages << PAGE_SHIFT) / (1 << 14) / sizeof(struct hlist_head))); + max_entries = table_size * 8; + + table_v4 = vmalloc(table_size * sizeof(struct hlist_head)); + if (!table_v4) + goto err_kmemcache; + __hash_init(table_v4, table_size); + #if IS_ENABLED(CONFIG_IPV6) - v6_match = xt_request_find_match(NFPROTO_IPV6, "hashlimit", 1); - if (IS_ERR(v6_match)) { - pr_err("The xt_hashlimit module for IPv6 is required\n"); - module_put(v4_match->me); - return PTR_ERR(v6_match); + table_v6 = vmalloc(table_size * sizeof(struct hlist_head)); + if (!table_v6) { + vfree(table_v4); + goto err_kmemcache; } + __hash_init(table_v6, table_size); #endif + + queue_delayed_work(system_power_efficient_wq, &gc_work, HZ); + get_random_bytes(&key, sizeof(key)); return 0; + +err_kmemcache: + kmem_cache_destroy(entry_cache); +err: + atomic64_dec(&refcnt); + return -ENOMEM; } -void ratelimiter_module_deinit(void) +void ratelimiter_uninit(void) { - module_put(v4_match->me); + if (atomic64_dec_return(&refcnt)) + return; + + cancel_delayed_work_sync(&gc_work); + gc_entries(NULL); + synchronize_rcu(); + vfree(table_v4); #if IS_ENABLED(CONFIG_IPV6) - module_put(v6_match->me); + vfree(table_v6); #endif + kmem_cache_destroy(entry_cache); } |