diff options
Diffstat (limited to 'src/netlink.c')
-rw-r--r-- | src/netlink.c | 37 |
1 files changed, 21 insertions, 16 deletions
diff --git a/src/netlink.c b/src/netlink.c index fc27b7f..e727669 100644 --- a/src/netlink.c +++ b/src/netlink.c @@ -63,22 +63,20 @@ static struct wireguard_device *lookup_interface(struct nlattr **attrs, struct s struct allowedips_ctx { struct sk_buff *skb; - unsigned int idx_cursor, idx; + unsigned int i; }; -static int get_allowedips(void *ctx, union nf_inet_addr ip, u8 cidr, int family) +static int get_allowedips(void *ctx, const u8 *ip, u8 cidr, int family) { struct nlattr *allowedip_nest; struct allowedips_ctx *actx = ctx; - if (++actx->idx < actx->idx_cursor) - return 0; - allowedip_nest = nla_nest_start(actx->skb, actx->idx - 1); + allowedip_nest = nla_nest_start(actx->skb, actx->i++); if (!allowedip_nest) return -EMSGSIZE; if (nla_put_u8(actx->skb, WGALLOWEDIP_A_CIDR_MASK, cidr) || nla_put_u16(actx->skb, WGALLOWEDIP_A_FAMILY, family) || - nla_put(actx->skb, WGALLOWEDIP_A_IPADDR, family == AF_INET6 ? sizeof(struct in6_addr) : sizeof(struct in_addr), &ip)) { + nla_put(actx->skb, WGALLOWEDIP_A_IPADDR, family == AF_INET6 ? sizeof(struct in6_addr) : sizeof(struct in_addr), ip)) { nla_nest_cancel(actx->skb, allowedip_nest); return -EMSGSIZE; } @@ -87,9 +85,9 @@ static int get_allowedips(void *ctx, union nf_inet_addr ip, u8 cidr, int family) return 0; } -static int get_peer(struct wireguard_peer *peer, unsigned int index, unsigned int *allowedips_idx_cursor, struct sk_buff *skb) +static int get_peer(struct wireguard_peer *peer, unsigned int index, struct routing_table_cursor *rt_cursor, struct sk_buff *skb) { - struct allowedips_ctx ctx = { .skb = skb, .idx_cursor = *allowedips_idx_cursor }; + struct allowedips_ctx ctx = { .skb = skb }; struct nlattr *allowedips_nest, *peer_nest = nla_nest_start(skb, index); bool fail; @@ -102,7 +100,7 @@ static int get_peer(struct wireguard_peer *peer, unsigned int index, unsigned in if (fail) goto err; - if (!ctx.idx_cursor) { + if (!rt_cursor->seq) { down_read(&peer->handshake.lock); fail = nla_put(skb, WGPEER_A_PRESHARED_KEY, NOISE_SYMMETRIC_KEY_LEN, peer->handshake.preshared_key); up_read(&peer->handshake.lock); @@ -126,13 +124,12 @@ static int get_peer(struct wireguard_peer *peer, unsigned int index, unsigned in allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS); if (!allowedips_nest) goto err; - if (routing_table_walk_ips_by_peer(&peer->device->peer_routing_table, &ctx, peer, get_allowedips, &peer->device->device_update_lock)) { - *allowedips_idx_cursor = ctx.idx; + if (routing_table_walk_by_peer(&peer->device->peer_routing_table, rt_cursor, peer, get_allowedips, &ctx, &peer->device->device_update_lock)) { nla_nest_end(skb, allowedips_nest); nla_nest_end(skb, peer_nest); return -EMSGSIZE; } - *allowedips_idx_cursor = 0; + memset(rt_cursor, 0, sizeof(*rt_cursor)); nla_nest_end(skb, allowedips_nest); nla_nest_end(skb, peer_nest); return 0; @@ -149,9 +146,15 @@ static int get_device_start(struct netlink_callback *cb) if (ret < 0) return ret; + cb->args[2] = (long)kzalloc(sizeof(struct routing_table_cursor), GFP_KERNEL); + if (!cb->args[2]) + return -ENOMEM; wg = lookup_interface(attrs, cb->skb); - if (IS_ERR(wg)) + if (IS_ERR(wg)) { + kfree((void *)cb->args[2]); + cb->args[2] = 0; return PTR_ERR(wg); + } cb->args[0] = (long)wg; return 0; } @@ -160,7 +163,8 @@ static int get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) { struct wireguard_device *wg = (struct wireguard_device *)cb->args[0]; struct wireguard_peer *peer, *next_peer_cursor = NULL, *last_peer_cursor = (struct wireguard_peer *)cb->args[1]; - unsigned int peer_idx = 0, allowedips_idx_cursor = (unsigned int)cb->args[2]; + struct routing_table_cursor *rt_cursor = (struct routing_table_cursor *)cb->args[2]; + unsigned int peer_idx = 0; struct nlattr *peers_nest; bool done = true; void *hdr; @@ -203,7 +207,7 @@ static int get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) lockdep_assert_held(&wg->device_update_lock); peer = list_prepare_entry(last_peer_cursor, &wg->peer_list, peer_list); list_for_each_entry_continue (peer, &wg->peer_list, peer_list) { - if (get_peer(peer, peer_idx++, &allowedips_idx_cursor, skb)) { + if (get_peer(peer, peer_idx++, rt_cursor, skb)) { done = false; break; } @@ -228,7 +232,6 @@ out: return 0; } cb->args[1] = (long)next_peer_cursor; - cb->args[2] = (long)allowedips_idx_cursor; return skb->len; /* At this point, we can't really deal ourselves with safely zeroing out @@ -240,9 +243,11 @@ static int get_device_done(struct netlink_callback *cb) { struct wireguard_device *wg = (struct wireguard_device *)cb->args[0]; struct wireguard_peer *peer = (struct wireguard_peer *)cb->args[1]; + struct routing_table_cursor *rt_cursor = (struct routing_table_cursor *)cb->args[2]; if (wg) dev_put(wg->dev); + kfree(rt_cursor); peer_put(peer); return 0; } |