-rw-r--r--   src/Kbuild                    |   5
-rw-r--r--   src/Kconfig                   |   5
-rw-r--r--   src/compat/compat.h           |  21
-rw-r--r--   src/cookie.c                  |  14
-rw-r--r--   src/cookie.h                  |   5
-rw-r--r--   src/device.c                  |   8
-rw-r--r--   src/main.c                    |   6
-rw-r--r--   src/ratelimiter.c             | 249
-rw-r--r--   src/ratelimiter.h             |  22
-rw-r--r--   src/tests/debug.mk            |   2
-rw-r--r--   src/tests/qemu/kernel.config  |   1
11 files changed, 179 insertions(+), 159 deletions(-)
diff --git a/src/Kbuild b/src/Kbuild
--- a/src/Kbuild
+++ b/src/Kbuild
@@ -29,11 +29,6 @@ CONFIG_WIREGUARD := m
 ifneq ($(CONFIG_SMP),)
 ccflags-y += -DCONFIG_WIREGUARD_PARALLEL=y
 endif
-ifneq ($(CONFIG_MODULES),)
-ifeq ($(CONFIG_NETFILTER_XT_MATCH_HASHLIMIT),)
-$(error "WireGuard requires CONFIG_NETFILTER_XT_MATCH_HASHLIMIT to be configured in your kernel. See https://www.wireguard.io/install/#kernel-requirements for more info")
-endif
-endif
 endif
 
 include $(src)/compat/Kbuild.include
diff --git a/src/Kconfig b/src/Kconfig
index 5b738ab..f2e25bb 100644
--- a/src/Kconfig
+++ b/src/Kconfig
@@ -2,12 +2,7 @@ config WIREGUARD
 	tristate "IP: WireGuard secure network tunnel"
 	depends on NET && INET
 	select NET_UDP_TUNNEL
-	select NETFILTER_XT_MATCH_HASHLIMIT
-	select NETFILTER
-	select NETFILTER_XTABLES
-	select NETFILTER_ADVANCED
 	select CRYPTO_BLKCIPHER
-	select IP6_NF_IPTABLES if IPV6
 	select NEON
 	select KERNEL_MODE_NEON
 	default m
diff --git a/src/compat/compat.h b/src/compat/compat.h
index 6c1bfa3..feb4347 100644
--- a/src/compat/compat.h
+++ b/src/compat/compat.h
@@ -11,15 +11,6 @@
 #error "WireGuard requires Linux >= 3.10"
 #endif
 
-/* These conditionals can't be enforced by an out of tree module very easily,
- * so we stick them here in compat instead. */
-#if !IS_ENABLED(CONFIG_NETFILTER_XT_MATCH_HASHLIMIT)
-#error "WireGuard requires CONFIG_NETFILTER_XT_MATCH_HASHLIMIT."
-#endif
-#if IS_ENABLED(CONFIG_IPV6) && !IS_ENABLED(CONFIG_IP6_NF_IPTABLES)
-#error "WireGuard requires CONFIG_IP6_NF_IPTABLES when using CONFIG_IPV6."
-#endif
-
 #if LINUX_VERSION_CODE < KERNEL_VERSION(4, 0, 0) && defined(CONFIG_X86_64)
 #define CONFIG_AS_SSSE3
 #endif
@@ -276,6 +267,18 @@ static inline int get_random_bytes_wait(void *buf, int nbytes)
 }
 #endif
 
+#if LINUX_VERSION_CODE < KERNEL_VERSION(3, 11, 0)
+#define system_power_efficient_wq system_unbound_wq
+#endif
+
+#if LINUX_VERSION_CODE < KERNEL_VERSION(3, 17, 0)
+#include <linux/ktime.h>
+static inline u64 ktime_get_ns(void)
+{
+	return ktime_to_ns(ktime_get());
+}
+#endif
+
 /* https://lkml.org/lkml/2015/6/12/415 */
 #include <linux/netdevice.h>
 static inline struct net_device *netdev_pub(void *dev)
diff --git a/src/cookie.c b/src/cookie.c
index ce22b53..0e9c211 100644
--- a/src/cookie.c
+++ b/src/cookie.c
@@ -4,6 +4,7 @@
 #include "peer.h"
 #include "device.h"
 #include "messages.h"
+#include "ratelimiter.h"
 #include "crypto/blake2s.h"
 #include "crypto/chacha20poly1305.h"
 
@@ -11,16 +12,12 @@
 #include <net/ipv6.h>
 #include <crypto/algapi.h>
 
-int cookie_checker_init(struct cookie_checker *checker, struct wireguard_device *wg)
+void cookie_checker_init(struct cookie_checker *checker, struct wireguard_device *wg)
 {
-	int ret = ratelimiter_init(&checker->ratelimiter, wg);
-	if (ret)
-		return ret;
 	init_rwsem(&checker->secret_lock);
 	checker->secret_birthdate = get_jiffies_64();
 	get_random_bytes(checker->secret, NOISE_HASH_LEN);
 	checker->device = wg;
-	return 0;
 }
 
 enum { COOKIE_KEY_LABEL_LEN = 8 };
@@ -56,11 +53,6 @@ void cookie_checker_precompute_peer_keys(struct wireguard_peer *peer)
 	precompute_key(peer->latest_cookie.message_mac1_key, peer->handshake.remote_static, mac1_key_label);
 }
 
-void cookie_checker_uninit(struct cookie_checker *checker)
-{
-	ratelimiter_uninit(&checker->ratelimiter);
-}
-
 void cookie_init(struct cookie *cookie)
 {
 	memset(cookie, 0, sizeof(struct cookie));
@@ -127,7 +119,7 @@ enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, str
 		goto out;
 
 	ret = VALID_MAC_WITH_COOKIE_BUT_RATELIMITED;
-	if (!ratelimiter_allow(&checker->ratelimiter, skb))
+	if (!ratelimiter_allow(skb, dev_net(netdev_pub(checker->device))))
 		goto out;
 
 	ret = VALID_MAC_WITH_COOKIE;
diff --git a/src/cookie.h b/src/cookie.h
index c87d3dd..54d0d99 100644
--- a/src/cookie.h
+++ b/src/cookie.h
@@ -4,7 +4,6 @@
 #define WGCOOKIE_H
 
 #include "messages.h"
-#include "ratelimiter.h"
 #include <linux/rwsem.h>
 
 struct wireguard_peer;
@@ -17,7 +16,6 @@ struct cookie_checker {
 	u8 message_mac1_key[NOISE_SYMMETRIC_KEY_LEN];
 	u64 secret_birthdate;
 	struct rw_semaphore secret_lock;
-	struct ratelimiter ratelimiter;
 	struct wireguard_device *device;
 };
 
@@ -39,8 +37,7 @@ enum cookie_mac_state {
 	VALID_MAC_WITH_COOKIE
 };
 
-int cookie_checker_init(struct cookie_checker *checker, struct wireguard_device *wg);
-void cookie_checker_uninit(struct cookie_checker *checker);
+void cookie_checker_init(struct cookie_checker *checker, struct wireguard_device *wg);
 void cookie_checker_precompute_device_keys(struct cookie_checker *checker);
 void cookie_checker_precompute_peer_keys(struct wireguard_peer *peer);
 void cookie_init(struct cookie *cookie);
diff --git a/src/device.c b/src/device.c
index 7a2948a..1b975d9 100644
--- a/src/device.c
+++ b/src/device.c
@@ -5,6 +5,7 @@
 #include "timers.h"
 #include "device.h"
 #include "config.h"
+#include "ratelimiter.h"
 #include "peer.h"
 #include "uapi.h"
 #include "messages.h"
@@ -251,10 +252,10 @@ static void destruct(struct net_device *dev)
 	destroy_workqueue(wg->crypt_wq);
 #endif
 	routing_table_free(&wg->peer_routing_table);
+	ratelimiter_uninit();
 	memzero_explicit(&wg->static_identity, sizeof(struct noise_static_identity));
 	skb_queue_purge(&wg->incoming_handshakes);
 	socket_uninit(wg);
-	cookie_checker_uninit(&wg->cookie_checker);
 	mutex_unlock(&wg->device_update_lock);
 	free_percpu(dev->tstats);
 	free_percpu(wg->incoming_handshakes_worker);
@@ -314,6 +315,7 @@ static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *t
 	pubkey_hashtable_init(&wg->peer_hashtable);
 	index_hashtable_init(&wg->index_hashtable);
 	routing_table_init(&wg->peer_routing_table);
+	cookie_checker_init(&wg->cookie_checker, wg);
 	INIT_LIST_HEAD(&wg->peer_list);
 
 	dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats);
@@ -353,7 +355,7 @@ static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *t
 	padata_start(wg->decrypt_pd);
 #endif
 
-	ret = cookie_checker_init(&wg->cookie_checker, wg);
+	ret = ratelimiter_init();
 	if (ret < 0)
 		goto error_8;
 
@@ -368,8 +370,8 @@ static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *t
 
 #if LINUX_VERSION_CODE < KERNEL_VERSION(4, 12, 0)
 error_9:
+	ratelimiter_uninit();
 #endif
-	cookie_checker_uninit(&wg->cookie_checker);
 error_8:
 #ifdef CONFIG_WIREGUARD_PARALLEL
 	padata_free(wg->decrypt_pd);
diff --git a/src/main.c b/src/main.c
--- a/src/main.c
+++ b/src/main.c
@@ -26,10 +26,6 @@ static int __init mod_init(void)
 #endif
 	noise_init();
 
-	ret = ratelimiter_module_init();
-	if (ret < 0)
-		return ret;
-
 #ifdef CONFIG_WIREGUARD_PARALLEL
 	ret = packet_init_data_caches();
 	if (ret < 0)
@@ -50,7 +46,6 @@ err_device:
 	packet_deinit_data_caches();
 err_packet:
 #endif
-	ratelimiter_module_deinit();
 	return ret;
 }
 
@@ -60,7 +55,6 @@ static void __exit mod_exit(void)
 #ifdef CONFIG_WIREGUARD_PARALLEL
 	packet_deinit_data_caches();
 #endif
-	ratelimiter_module_deinit();
 	pr_debug("WireGuard unloaded\n");
 }
 
diff --git a/src/ratelimiter.c b/src/ratelimiter.c
index ab8f93d..2d2e758 100644
--- a/src/ratelimiter.c
+++ b/src/ratelimiter.c
@@ -1,138 +1,195 @@
 /* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
 
 #include "ratelimiter.h"
-#include "peer.h"
-#include "device.h"
-
-#include <linux/module.h>
-#include <linux/netfilter/x_tables.h>
+#include <linux/siphash.h>
+#include <linux/vmalloc.h>
+#include <linux/slab.h>
+#include <linux/hashtable.h>
 #include <net/ip.h>
 
-static struct xt_match *v4_match __read_mostly;
+static struct kmem_cache *entry_cache;
+static hsiphash_key_t key;
+static spinlock_t table_lock = __SPIN_LOCK_UNLOCKED("ratelimiter_table_lock");
+static atomic64_t refcnt = ATOMIC64_INIT(0);
+static atomic_t total_entries = ATOMIC_INIT(0);
+static unsigned int max_entries, table_size;
+static void gc_entries(struct work_struct *);
+static DECLARE_DEFERRABLE_WORK(gc_work, gc_entries);
+static struct hlist_head *table_v4;
 #if IS_ENABLED(CONFIG_IPV6)
-static struct xt_match *v6_match __read_mostly;
+static struct hlist_head *table_v6;
 #endif
 
+struct entry {
+	u64 last_time_ns, tokens;
+	void *net;
+	__be32 ip[3];
+	spinlock_t lock;
+	struct hlist_node hash;
+	struct rcu_head rcu;
+};
+
 enum {
-	RATELIMITER_PACKETS_PER_SECOND = 30,
-	RATELIMITER_PACKETS_BURSTABLE = 5
+	PACKETS_PER_SECOND = 20,
+	PACKETS_BURSTABLE = 5,
+	PACKET_COST = NSEC_PER_SEC / PACKETS_PER_SECOND,
+	TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE
 };
 
-static inline void cfg_init(struct hashlimit_cfg1 *cfg, int family)
+static void entry_free(struct rcu_head *rcu)
 {
-	memset(cfg, 0, sizeof(struct hashlimit_cfg1));
-	if (family == NFPROTO_IPV4)
-		cfg->srcmask = 32;
-	else if (family == NFPROTO_IPV6)
-		cfg->srcmask = 96;
-	cfg->mode = XT_HASHLIMIT_HASH_SIP; /* source IP only -- we could also do source port by ORing this with XT_HASHLIMIT_HASH_SPT, but we don't really want to do that. It would also cause problems since we skb_pull early on, and hashlimit's nexthdr stuff isn't so nice. */
-	cfg->avg = XT_HASHLIMIT_SCALE / RATELIMITER_PACKETS_PER_SECOND; /* 30 per second per IP */
-	cfg->burst = RATELIMITER_PACKETS_BURSTABLE; /* Allow bursts of 5 at a time */
-	cfg->gc_interval = 1000; /* same as expiration date */
-	cfg->expire = 1000; /* Units of avg (seconds = 1) times 1000 */
-	/* cfg->size and cfg->max are computed based on the memory size of left to zero */
+	kmem_cache_free(entry_cache, container_of(rcu, struct entry, rcu));
+	atomic_dec(&total_entries);
 }
 
-int ratelimiter_init(struct ratelimiter *ratelimiter, struct wireguard_device *wg)
+static void entry_uninit(struct entry *entry)
 {
-	struct net_device *dev = netdev_pub(wg);
-	struct xt_mtchk_param chk = { .net = wg->creating_net };
-	int ret;
-
-	memset(ratelimiter, 0, sizeof(struct ratelimiter));
-
-	cfg_init(&ratelimiter->v4_info.cfg, NFPROTO_IPV4);
-	memcpy(ratelimiter->v4_info.name, dev->name, IFNAMSIZ);
-	chk.matchinfo = &ratelimiter->v4_info;
-	chk.match = v4_match;
-	chk.family = NFPROTO_IPV4;
-	ret = v4_match->checkentry(&chk);
-	if (ret < 0)
-		return ret;
-
-#if IS_ENABLED(CONFIG_IPV6)
-	cfg_init(&ratelimiter->v6_info.cfg, NFPROTO_IPV6);
-	memcpy(ratelimiter->v6_info.name, dev->name, IFNAMSIZ);
-	chk.matchinfo = &ratelimiter->v6_info;
-	chk.match = v6_match;
-	chk.family = NFPROTO_IPV6;
-	ret = v6_match->checkentry(&chk);
-	if (ret < 0) {
-		struct xt_mtdtor_param dtor_v4 = {
-			.net = wg->creating_net,
-			.match = v4_match,
-			.matchinfo = &ratelimiter->v4_info,
-			.family = NFPROTO_IPV4
-		};
-		v4_match->destroy(&dtor_v4);
-		return ret;
-	}
-#endif
-
-	ratelimiter->net = wg->creating_net;
-	return 0;
+	hlist_del_rcu(&entry->hash);
+	call_rcu_bh(&entry->rcu, entry_free);
 }
 
-void ratelimiter_uninit(struct ratelimiter *ratelimiter)
+/* Calling this function with a NULL work uninits all entries. */
+static void gc_entries(struct work_struct *work)
 {
-	struct xt_mtdtor_param dtor = { .net = ratelimiter->net };
-
-	dtor.match = v4_match;
-	dtor.matchinfo = &ratelimiter->v4_info;
-	dtor.family = NFPROTO_IPV4;
-	v4_match->destroy(&dtor);
-
+	unsigned int i;
+	struct entry *entry;
+	struct hlist_node *temp;
+	const u64 now = ktime_get_ns();
+
+	for (i = 0; i < table_size; ++i) {
+		spin_lock(&table_lock);
+		hlist_for_each_entry_safe (entry, temp, &table_v4[i], hash) {
+			if (unlikely(!work) || now - entry->last_time_ns > NSEC_PER_SEC)
+				entry_uninit(entry);
+		}
 #if IS_ENABLED(CONFIG_IPV6)
-	dtor.match = v6_match;
-	dtor.matchinfo = &ratelimiter->v6_info;
-	dtor.family = NFPROTO_IPV6;
-	v6_match->destroy(&dtor);
+		hlist_for_each_entry_safe (entry, temp, &table_v6[i], hash) {
+			if (unlikely(!work) || now - entry->last_time_ns > NSEC_PER_SEC)
+				entry_uninit(entry);
+		}
 #endif
+		spin_unlock(&table_lock);
+		if (likely(work))
+			cond_resched();
+	}
+	if (likely(work))
+		queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
 }
 
-bool ratelimiter_allow(struct ratelimiter *ratelimiter, struct sk_buff *skb)
+bool ratelimiter_allow(struct sk_buff *skb, struct net *net)
 {
-	struct xt_action_param action = { { NULL } };
-	if (unlikely(skb->len < sizeof(struct iphdr)))
-		return false;
-	if (ip_hdr(skb)->version == 4) {
-		action.match = v4_match;
-		action.matchinfo = &ratelimiter->v4_info;
-		action.thoff = ip_hdrlen(skb);
+	struct entry *entry;
+	struct hlist_head *bucket;
+	struct { u32 net; __be32 ip[3]; } data = { .net = (unsigned long)net & 0xffffffff };
+
+	if (skb->len >= sizeof(struct iphdr) && ip_hdr(skb)->version == 4) {
+		data.ip[0] = ip_hdr(skb)->saddr;
+		bucket = &table_v4[hsiphash(&data, sizeof(u32) * 2, &key) & (table_size - 1)];
 	}
 #if IS_ENABLED(CONFIG_IPV6)
-	else if (ip_hdr(skb)->version == 6) {
-		action.match = v6_match;
-		action.matchinfo = &ratelimiter->v6_info;
+	else if (skb->len >= sizeof(struct ipv6hdr) && ip_hdr(skb)->version == 6) {
+		memcpy(data.ip, &ipv6_hdr(skb)->saddr, sizeof(u32) * 3); /* Only 96 bits */
+		bucket = &table_v6[hsiphash(&data, sizeof(u32) * 4, &key) & (table_size - 1)];
 	}
 #endif
 	else
 		return false;
-	return action.match->match(skb, &action);
+
+	rcu_read_lock();
+	hlist_for_each_entry_rcu (entry, bucket, hash) {
+		if (entry->net == net && !memcmp(entry->ip, data.ip, sizeof(data.ip))) {
+			u64 now, tokens;
+			bool ret;
+			/* Inspired by nft_limit.c, but this is actually a slightly different
+			 * algorithm. Namely, we incorporate the burst as part of the maximum
+			 * tokens, rather than as part of the rate. */
+			spin_lock(&entry->lock);
+			now = ktime_get_ns();
+			tokens = min_t(u64, TOKEN_MAX, entry->tokens + now - entry->last_time_ns);
+			entry->last_time_ns = now;
+			ret = tokens >= PACKET_COST;
+			entry->tokens = ret ? tokens - PACKET_COST : tokens;
+			spin_unlock(&entry->lock);
+			rcu_read_unlock();
+			return ret;
+		}
+	}
+	rcu_read_unlock();
+
+	if (atomic_inc_return(&total_entries) > max_entries)
+		goto err_oom;
+
+	entry = kmem_cache_alloc(entry_cache, GFP_KERNEL);
+	if (!entry)
+		goto err_oom;
+
+	entry->net = net;
+	memcpy(entry->ip, data.ip, sizeof(data.ip));
+	INIT_HLIST_NODE(&entry->hash);
+	spin_lock_init(&entry->lock);
+	entry->last_time_ns = ktime_get_ns();
+	entry->tokens = TOKEN_MAX - PACKET_COST;
+	spin_lock(&table_lock);
+	hlist_add_head_rcu(&entry->hash, bucket);
+	spin_unlock(&table_lock);
+	return true;
+
+err_oom:
+	atomic_dec(&total_entries);
+	return false;
 }
 
-int ratelimiter_module_init(void)
+int ratelimiter_init(void)
 {
-	v4_match = xt_request_find_match(NFPROTO_IPV4, "hashlimit", 1);
-	if (IS_ERR(v4_match)) {
-		pr_err("The xt_hashlimit module for IPv4 is required\n");
-		return PTR_ERR(v4_match);
-	}
+	if (atomic64_inc_return(&refcnt) != 1)
+		return 0;
+
+	entry_cache = kmem_cache_create("wireguard_ratelimiter", sizeof(struct entry), 0, 0, NULL);
+	if (!entry_cache)
+		goto err;
+
+	/* xt_hashlimit.c uses a slightly different algorithm for ratelimiting,
+	 * but what it shares in common is that it uses a massive hashtable. So,
+	 * we borrow their wisdom about good table sizes on different systems
+	 * dependent on RAM. This calculation here comes from there. */
+	table_size = (totalram_pages > (1 << 30) / PAGE_SIZE) ? 8192 : max_t(unsigned long, 16, roundup_pow_of_two((totalram_pages << PAGE_SHIFT) / (1 << 14) / sizeof(struct hlist_head)));
+	max_entries = table_size * 8;
+
+	table_v4 = vmalloc(table_size * sizeof(struct hlist_head));
+	if (!table_v4)
+		goto err_kmemcache;
+	__hash_init(table_v4, table_size);
+
 #if IS_ENABLED(CONFIG_IPV6)
-	v6_match = xt_request_find_match(NFPROTO_IPV6, "hashlimit", 1);
-	if (IS_ERR(v6_match)) {
-		pr_err("The xt_hashlimit module for IPv6 is required\n");
-		module_put(v4_match->me);
-		return PTR_ERR(v6_match);
+	table_v6 = vmalloc(table_size * sizeof(struct hlist_head));
+	if (!table_v6) {
+		vfree(table_v4);
+		goto err_kmemcache;
 	}
+	__hash_init(table_v6, table_size);
 #endif
+
+	queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
+	get_random_bytes(&key, sizeof(key));
 	return 0;
+
+err_kmemcache:
+	kmem_cache_destroy(entry_cache);
+err:
+	atomic64_dec(&refcnt);
+	return -ENOMEM;
 }
 
-void ratelimiter_module_deinit(void)
+void ratelimiter_uninit(void)
 {
-	module_put(v4_match->me);
+	if (atomic64_dec_return(&refcnt))
+		return;
+
+	cancel_delayed_work_sync(&gc_work);
+	gc_entries(NULL);
+	synchronize_rcu();
+	vfree(table_v4);
 #if IS_ENABLED(CONFIG_IPV6)
-	module_put(v6_match->me);
+	vfree(table_v6);
 #endif
+	kmem_cache_destroy(entry_cache);
 }
diff --git a/src/ratelimiter.h b/src/ratelimiter.h
index c4dc9a7..fed73f7 100644
--- a/src/ratelimiter.h
+++ b/src/ratelimiter.h
@@ -3,24 +3,10 @@
 #ifndef RATELIMITER_H
 #define RATELIMITER_H
 
-#include <uapi/linux/netfilter/xt_hashlimit.h>
+#include <linux/skbuff.h>
 
-struct wireguard_device;
-struct sk_buff;
-
-struct ratelimiter {
-	struct net *net;
-	struct xt_hashlimit_mtinfo1 v4_info;
-#if IS_ENABLED(CONFIG_IPV6)
-	struct xt_hashlimit_mtinfo1 v6_info;
-#endif
-};
-
-int ratelimiter_init(struct ratelimiter *ratelimiter, struct wireguard_device *wg);
-void ratelimiter_uninit(struct ratelimiter *ratelimiter);
-bool ratelimiter_allow(struct ratelimiter *ratelimiter, struct sk_buff *skb);
-
-int ratelimiter_module_init(void);
-void ratelimiter_module_deinit(void);
+int ratelimiter_init(void);
+void ratelimiter_uninit(void);
+bool ratelimiter_allow(struct sk_buff *skb, struct net *net);
 
 #endif
diff --git a/src/tests/debug.mk b/src/tests/debug.mk
index 9e34e53..5078a87 100644
--- a/src/tests/debug.mk
+++ b/src/tests/debug.mk
@@ -18,9 +18,7 @@ endif
 test: debug
 	-sudo modprobe ip6_udp_tunnel
 	-sudo modprobe udp_tunnel
-	-sudo modprobe x_tables
 	-sudo modprobe ipv6
-	-sudo modprobe xt_hashlimit
 	-sudo modprobe nf_conntrack_ipv4
 	-sudo modprobe nf_conntrack_ipv6
 	-sudo rmmod wireguard
diff --git a/src/tests/qemu/kernel.config b/src/tests/qemu/kernel.config
index 4e4f573..5469448 100644
--- a/src/tests/qemu/kernel.config
+++ b/src/tests/qemu/kernel.config
@@ -9,6 +9,7 @@ CONFIG_NET_NS=y
 CONFIG_UNIX=y
 CONFIG_INET=y
 CONFIG_IPV6=y
+CONFIG_NETFILTER=y
 CONFIG_NF_CONNTRACK=y
 CONFIG_NF_NAT=y
 CONFIG_NETFILTER_XTABLES=y
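
Note on the new rate-limiting algorithm: the core of ratelimiter_allow() above is a token-bucket update. Elapsed nanoseconds are added to a per-source token balance, the balance is clamped to TOKEN_MAX (so the burst allowance is folded into the ceiling rather than into the rate), and a packet passes only if PACKET_COST tokens can be deducted. The standalone userspace sketch below reproduces just that arithmetic with the constants from the patch; struct bucket, bucket_allow() and the main() driver are invented here purely for illustration and are not part of WireGuard.

/* Userspace sketch of the token-bucket rule used by the new ratelimiter.
 * Burst is expressed as a maximum token balance (TOKEN_MAX), not as a rate. */
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

enum {
	PACKETS_PER_SECOND = 20,
	PACKETS_BURSTABLE = 5,
	NSEC_PER_SECOND = 1000000000,
	PACKET_COST = NSEC_PER_SECOND / PACKETS_PER_SECOND,
	TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE
};

struct bucket {
	uint64_t last_time_ns;
	uint64_t tokens;
};

/* Returns true if a packet arriving at now_ns may pass. */
static bool bucket_allow(struct bucket *b, uint64_t now_ns)
{
	uint64_t tokens = b->tokens + (now_ns - b->last_time_ns);

	if (tokens > TOKEN_MAX)
		tokens = TOKEN_MAX;
	b->last_time_ns = now_ns;
	if (tokens >= PACKET_COST) {
		b->tokens = tokens - PACKET_COST;
		return true;
	}
	b->tokens = tokens;
	return false;
}

int main(void)
{
	/* A fresh entry in the patch starts at TOKEN_MAX - PACKET_COST,
	 * i.e. the packet that created it has already been charged. */
	struct bucket b = { .last_time_ns = 0, .tokens = TOKEN_MAX - PACKET_COST };
	int allowed = 0;

	/* 10 back-to-back packets at t=0: only the remaining burst of 4 passes. */
	for (int i = 0; i < 10; ++i)
		allowed += bucket_allow(&b, 0);
	printf("burst allowed: %d\n", allowed);   /* prints 4 */

	/* One packet every 50 ms (20 per second) is sustained indefinitely. */
	allowed = 0;
	for (int i = 1; i <= 100; ++i)
		allowed += bucket_allow(&b, (uint64_t)i * PACKET_COST);
	printf("steady allowed: %d\n", allowed);  /* prints 100 */
	return 0;
}

With a fresh entry starting at TOKEN_MAX - PACKET_COST, at most four more back-to-back packets pass before the source is limited, after which it is held to 20 packets per second. The hashtable backing these entries is sized the same way xt_hashlimit sizes its tables: on a machine with more than 1 GiB of RAM the patch uses 8192 buckets, and max_entries = table_size * 8 then caps tracking at 65536 sources.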