diff options
Diffstat (limited to 'src/netlink.c')
-rw-r--r-- | src/netlink.c | 61 |
1 files changed, 35 insertions, 26 deletions
diff --git a/src/netlink.c b/src/netlink.c index f44f211..b179b31 100644 --- a/src/netlink.c +++ b/src/netlink.c @@ -69,9 +69,9 @@ static struct wg_device *lookup_interface(struct nlattr **attrs, return netdev_priv(dev); } -static int get_allowedips(void *ctx, const u8 *ip, u8 cidr, int family) +static int get_allowedips(struct sk_buff *skb, const u8 *ip, u8 cidr, + int family) { - struct sk_buff *skb = ctx; struct nlattr *allowedip_nest; allowedip_nest = nla_nest_start(skb, 0); @@ -90,10 +90,12 @@ static int get_allowedips(void *ctx, const u8 *ip, u8 cidr, int family) return 0; } -static int get_peer(struct wg_peer *peer, struct allowedips_cursor *rt_cursor, - struct sk_buff *skb) +static int +get_peer(struct wg_peer *peer, struct allowedips_node **next_allowedips_node, + u64 *allowedips_seq, struct sk_buff *skb) { struct nlattr *allowedips_nest, *peer_nest = nla_nest_start(skb, 0); + struct allowedips_node *allowedips_node = *next_allowedips_node; bool fail; if (!peer_nest) @@ -106,7 +108,7 @@ static int get_peer(struct wg_peer *peer, struct allowedips_cursor *rt_cursor, if (fail) goto err; - if (!rt_cursor->seq) { + if (!allowedips_node) { const struct __kernel_timespec last_handshake = { .tv_sec = peer->walltime_last_handshake.tv_sec, .tv_nsec = peer->walltime_last_handshake.tv_nsec @@ -143,21 +145,39 @@ static int get_peer(struct wg_peer *peer, struct allowedips_cursor *rt_cursor, read_unlock_bh(&peer->endpoint_lock); if (fail) goto err; + allowedips_node = + list_first_entry_or_null(&peer->allowedips_list, + struct allowedips_node, peer_list); } + if (!allowedips_node) + goto no_allowedips; + if (!*allowedips_seq) + *allowedips_seq = peer->device->peer_allowedips.seq; + else if (*allowedips_seq != peer->device->peer_allowedips.seq) + goto no_allowedips; allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS); if (!allowedips_nest) goto err; - if (wg_allowedips_walk_by_peer(&peer->device->peer_allowedips, - rt_cursor, peer, get_allowedips, skb, - &peer->device->device_update_lock)) { - nla_nest_end(skb, allowedips_nest); - nla_nest_end(skb, peer_nest); - return -EMSGSIZE; + + list_for_each_entry_from(allowedips_node, &peer->allowedips_list, + peer_list) { + u8 cidr, ip[16] __aligned(__alignof(u64)); + int family; + + family = wg_allowedips_read_node(allowedips_node, ip, &cidr); + if (get_allowedips(skb, ip, cidr, family)) { + nla_nest_end(skb, allowedips_nest); + nla_nest_end(skb, peer_nest); + *next_allowedips_node = allowedips_node; + return -EMSGSIZE; + } } - memset(rt_cursor, 0, sizeof(*rt_cursor)); nla_nest_end(skb, allowedips_nest); +no_allowedips: nla_nest_end(skb, peer_nest); + *next_allowedips_node = NULL; + *allowedips_seq = 0; return 0; err: nla_nest_cancel(skb, peer_nest); @@ -174,16 +194,9 @@ static int wg_get_device_start(struct netlink_callback *cb) genl_family.maxattr, device_policy, NULL); if (ret < 0) return ret; - cb->args[2] = (long)kzalloc(sizeof(struct allowedips_cursor), - GFP_KERNEL); - if (unlikely(!cb->args[2])) - return -ENOMEM; wg = lookup_interface(attrs, cb->skb); - if (IS_ERR(wg)) { - kfree((void *)cb->args[2]); - cb->args[2] = 0; + if (IS_ERR(wg)) return PTR_ERR(wg); - } cb->args[0] = (long)wg; return 0; } @@ -191,7 +204,6 @@ static int wg_get_device_start(struct netlink_callback *cb) static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) { struct wg_peer *peer, *next_peer_cursor, *last_peer_cursor; - struct allowedips_cursor *rt_cursor; struct nlattr *peers_nest; struct wg_device *wg; int ret = -EMSGSIZE; @@ -201,7 +213,6 @@ static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) wg = (struct wg_device *)cb->args[0]; next_peer_cursor = (struct wg_peer *)cb->args[1]; last_peer_cursor = (struct wg_peer *)cb->args[1]; - rt_cursor = (struct allowedips_cursor *)cb->args[2]; rtnl_lock(); mutex_lock(&wg->device_update_lock); @@ -253,7 +264,8 @@ static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) lockdep_assert_held(&wg->device_update_lock); peer = list_prepare_entry(last_peer_cursor, &wg->peer_list, peer_list); list_for_each_entry_continue(peer, &wg->peer_list, peer_list) { - if (get_peer(peer, rt_cursor, skb)) { + if (get_peer(peer, (struct allowedips_node **)&cb->args[2], + (u64 *)&cb->args[4] /* and args[5] */, skb)) { done = false; break; } @@ -290,12 +302,9 @@ static int wg_get_device_done(struct netlink_callback *cb) { struct wg_device *wg = (struct wg_device *)cb->args[0]; struct wg_peer *peer = (struct wg_peer *)cb->args[1]; - struct allowedips_cursor *rt_cursor = - (struct allowedips_cursor *)cb->args[2]; if (wg) dev_put(wg->dev); - kfree(rt_cursor); wg_peer_put(peer); return 0; } |