summaryrefslogtreecommitdiffhomepage
path: root/src/ratelimiter.c
blob: 6bf85b0bbcbb091b4e5ae0576f3c220fceb0fd3f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
/* Copyright 2015-2016 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */

#include "wireguard.h"
#include "ratelimiter.h"
#include <linux/netfilter/x_tables.h>
#include <linux/ipv6.h>
#include <net/ip.h>
#include <linux/module.h>

/* Tuning knobs for the per-source-address hashlimit configuration. */
enum {
	/* Number of packets a single source may send back-to-back. */
	RATELIMITER_PACKETS_BURSTABLE = 5,
	/* Sustained number of packets per second allowed per source. */
	RATELIMITER_PACKETS_PER_SECOND = 75
};

/*
 * cfg_init - fill in an xt_hashlimit configuration for one address family.
 * @cfg:    configuration to initialize. It is zeroed first, so every field
 *          not assigned below (notably ->size and ->max) stays 0; presumably
 *          xt_hashlimit then sizes the table itself -- TODO confirm.
 * @family: NFPROTO_IPV4 or NFPROTO_IPV6, selecting the source prefix length
 *          used for hashing. Any other value leaves ->srcmask at 0.
 */
static inline void cfg_init(struct hashlimit_cfg1 *cfg, int family)
{
	memset(cfg, 0, sizeof(*cfg));
	if (family == NFPROTO_IPV4)
		cfg->srcmask = 32;
	else if (family == NFPROTO_IPV6)
		cfg->srcmask = 96; /* NOTE(review): sources are grouped by their first 96 bits; confirm this prefix is intended. */
	/* Key buckets on the source IP only; ORing in XT_HASHLIMIT_HASH_SPT would also key on source port. */
	cfg->mode = XT_HASHLIMIT_HASH_SIP;
	/* Average rate: RATELIMITER_PACKETS_PER_SECOND packets per second per source address. */
	cfg->avg = XT_HASHLIMIT_SCALE / RATELIMITER_PACKETS_PER_SECOND;
	/* Allow bursts of RATELIMITER_PACKETS_BURSTABLE packets at a time. */
	cfg->burst = RATELIMITER_PACKETS_BURSTABLE;
	/* Garbage-collect at the same interval as entries expire. */
	cfg->gc_interval = 1000;
	/* Entry expiry; presumably milliseconds like iptables' --hashlimit-htable-expire, i.e. one second -- TODO confirm. */
	cfg->expire = 1000;
}

int ratelimiter_init(struct ratelimiter *ratelimiter, struct wireguard_device *wg)
{
	struct net_device *dev = netdev_pub(wg);
	struct xt_mtchk_param chk = { .net = wg->creating_net };
	int ret;

	memset(ratelimiter, 0, sizeof(struct ratelimiter));

	cfg_init(&ratelimiter->v4_info.cfg, NFPROTO_IPV4);
	cfg_init(&ratelimiter->v6_info.cfg, NFPROTO_IPV6);
	memcpy(ratelimiter->v4_info.name, dev->name, IFNAMSIZ);
	memcpy(ratelimiter->v6_info.name, dev->name, IFNAMSIZ);

	ratelimiter->v4_match = xt_request_find_match(NFPROTO_IPV4, "hashlimit", 1);
	if (IS_ERR(ratelimiter->v4_match)) {
		pr_err("The xt_hashlimit module is required");
		return PTR_ERR(ratelimiter->v4_match);
	}

	chk.matchinfo = &ratelimiter->v4_info;
	chk.match = ratelimiter->v4_match;
	chk.family = NFPROTO_IPV4;
	ret = ratelimiter->v4_match->checkentry(&chk);
	if (ret < 0) {
		module_put(ratelimiter->v4_match->me);
		return ret;
	}

	ratelimiter->v6_match = xt_request_find_match(NFPROTO_IPV6, "hashlimit", 1);
	if (IS_ERR(ratelimiter->v6_match)) {
		pr_err("The xt_hashlimit module is required");
		module_put(ratelimiter->v4_match->me);
		return PTR_ERR(ratelimiter->v6_match);
	}

	chk.matchinfo = &ratelimiter->v6_info;
	chk.match = ratelimiter->v6_match;
	chk.family = NFPROTO_IPV6;
	ret = ratelimiter->v6_match->checkentry(&chk);
	if (ret < 0) {
		struct xt_mtdtor_param dtor_v4 = {
			.net = wg->creating_net,
			.match = ratelimiter->v4_match,
			.matchinfo = &ratelimiter->v4_info,
			.family = NFPROTO_IPV4
		};
		ratelimiter->v4_match->destroy(&dtor_v4);
		module_put(ratelimiter->v4_match->me);
		module_put(ratelimiter->v6_match->me);
		return ret;
	}

	ratelimiter->net = wg->creating_net;
	return 0;
}

void ratelimiter_uninit(struct ratelimiter *ratelimiter)
{
	struct xt_mtdtor_param dtor = { .net = ratelimiter->net };

	dtor.match = ratelimiter->v4_match;
	dtor.matchinfo = &ratelimiter->v4_info;
	dtor.family = NFPROTO_IPV4;
	ratelimiter->v4_match->destroy(&dtor);
	module_put(ratelimiter->v4_match->me);

	dtor.match = ratelimiter->v6_match;
	dtor.matchinfo = &ratelimiter->v6_info;
	dtor.family = NFPROTO_IPV6;
	ratelimiter->v6_match->destroy(&dtor);
	module_put(ratelimiter->v6_match->me);
}

bool ratelimiter_allow(struct ratelimiter *ratelimiter, struct sk_buff *skb)
{
	struct xt_action_param action = { { NULL } };
	if (unlikely(skb->len < sizeof(struct iphdr)))
		return false;
	if (ip_hdr(skb)->version == 4) {
		action.match = ratelimiter->v4_match;
		action.matchinfo = &ratelimiter->v4_info;
		action.thoff = ip_hdrlen(skb);
		action.family = NFPROTO_IPV4;
	} else if (ip_hdr(skb)->version == 6) {
		action.match = ratelimiter->v6_match;
		action.matchinfo = &ratelimiter->v6_info;
		action.family = NFPROTO_IPV6;
	} else
		return false;
	return action.match->match(skb, &action);
}