diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/device.c | 6 | ||||
-rw-r--r-- | src/netlink.c | 8 | ||||
-rw-r--r-- | src/socket.c | 64 | ||||
-rw-r--r-- | src/socket.h | 4 |
4 files changed, 35 insertions, 47 deletions
diff --git a/src/device.c b/src/device.c index ffc36b4..12c11fb 100644 --- a/src/device.c +++ b/src/device.c @@ -51,7 +51,7 @@ static int open(struct net_device *dev) #endif #endif - ret = socket_init(wg); + ret = socket_init(wg, wg->incoming_port); if (ret < 0) return ret; mutex_lock(&wg->device_update_lock); @@ -105,7 +105,7 @@ static int stop(struct net_device *dev) } mutex_unlock(&wg->device_update_lock); skb_queue_purge(&wg->incoming_handshakes); - socket_uninit(wg); + socket_reinit(wg, NULL, NULL); return 0; } @@ -220,7 +220,7 @@ static void destruct(struct net_device *dev) ratelimiter_uninit(); memzero_explicit(&wg->static_identity, sizeof(struct noise_static_identity)); skb_queue_purge(&wg->incoming_handshakes); - socket_uninit(wg); + socket_reinit(wg, NULL, NULL); free_percpu(dev->tstats); free_percpu(wg->incoming_handshakes_worker); if (wg->have_creating_net_ref) diff --git a/src/netlink.c b/src/netlink.c index f2e724a..9297e60 100644 --- a/src/netlink.c +++ b/src/netlink.c @@ -260,13 +260,13 @@ static int set_port(struct wireguard_device *wg, u16 port) if (wg->incoming_port == port) return 0; - socket_uninit(wg); - wg->incoming_port = port; list_for_each_entry(peer, &wg->peer_list, peer_list) socket_clear_peer_endpoint_src(peer); - if (!netif_running(wg->dev)) + if (!netif_running(wg->dev)) { + wg->incoming_port = port; return 0; - return socket_init(wg); + } + return socket_init(wg, port); } static int set_allowedip(struct wireguard_peer *peer, struct nlattr **attrs) diff --git a/src/socket.c b/src/socket.c index 1ce74cd..5bf5a92 100644 --- a/src/socket.c +++ b/src/socket.c @@ -20,7 +20,6 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct .saddr = endpoint->src4.s_addr, .daddr = endpoint->addr4.sin_addr.s_addr, .fl4_dport = endpoint->addr4.sin_port, - .fl4_sport = htons(wg->incoming_port), .flowi4_mark = wg->fwmark, .flowi4_proto = IPPROTO_UDP }; @@ -34,6 +33,7 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct rcu_read_lock_bh(); sock = rcu_dereference_bh(wg->sock4); + fl.fl4_sport = inet_sk(sock)->inet_sport; if (unlikely(!sock)) { ret = -ENONET; @@ -89,7 +89,6 @@ static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct .saddr = endpoint->src6, .daddr = endpoint->addr6.sin6_addr, .fl6_dport = endpoint->addr6.sin6_port, - .fl6_sport = htons(wg->incoming_port), .flowi6_mark = wg->fwmark, .flowi6_oif = endpoint->addr6.sin6_scope_id, .flowi6_proto = IPPROTO_UDP @@ -105,6 +104,7 @@ static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct rcu_read_lock_bh(); sock = rcu_dereference_bh(wg->sock6); + fl.fl6_sport = inet_sk(sock)->inet_sport; if (unlikely(!sock)) { ret = -ENONET; @@ -309,87 +309,75 @@ static inline void set_sock_opts(struct socket *sock) sk_set_memalloc(sock->sk); } -int socket_init(struct wireguard_device *wg) +int socket_init(struct wireguard_device *wg, u16 port) { - int ret = 0; + int ret; struct udp_tunnel_sock_cfg cfg = { .sk_user_data = wg, .encap_type = 1, .encap_rcv = receive }; - struct socket *new4 = NULL; + struct socket *new4 = NULL, *new6 = NULL; struct udp_port_cfg port4 = { .family = AF_INET, .local_ip.s_addr = htonl(INADDR_ANY), - .local_udp_port = htons(wg->incoming_port), + .local_udp_port = htons(port), .use_udp_checksums = true }; #if IS_ENABLED(CONFIG_IPV6) int retries = 0; - struct socket *new6 = NULL; struct udp_port_cfg port6 = { .family = AF_INET6, .local_ip6 = IN6ADDR_ANY_INIT, - .local_udp_port = htons(wg->incoming_port), .use_udp6_tx_checksums = true, .use_udp6_rx_checksums = true, .ipv6_v6only = true }; #endif - mutex_lock(&wg->socket_update_lock); #if IS_ENABLED(CONFIG_IPV6) retry: #endif - if (rcu_dereference_protected(wg->sock4, lockdep_is_held(&wg->socket_update_lock)) || rcu_dereference_protected(wg->sock6, lockdep_is_held(&wg->socket_update_lock))) { - ret = -EADDRINUSE; - goto out; - } ret = udp_sock_create(wg->creating_net, &port4, &new4); if (ret < 0) { pr_err("%s: Could not create IPv4 socket\n", wg->dev->name); - goto out; + return ret; } - wg->incoming_port = ntohs(inet_sk(new4->sk)->inet_sport); set_sock_opts(new4); setup_udp_tunnel_sock(wg->creating_net, new4, &cfg); - rcu_assign_pointer(wg->sock4, new4->sk); #if IS_ENABLED(CONFIG_IPV6) - if (!ipv6_mod_enabled()) - goto out; - port6.local_udp_port = htons(wg->incoming_port); - ret = udp_sock_create(wg->creating_net, &port6, &new6); - if (ret < 0) { - udp_tunnel_sock_release(new4); - rcu_assign_pointer(wg->sock4, NULL); - if (ret == -EADDRINUSE && !port4.local_udp_port && retries++ < 100) - goto retry; - if (!port4.local_udp_port) - wg->incoming_port = 0; - pr_err("%s: Could not create IPv6 socket\n", wg->dev->name); - goto out; + if (ipv6_mod_enabled()) { + port6.local_udp_port = inet_sk(new4->sk)->inet_sport; + ret = udp_sock_create(wg->creating_net, &port6, &new6); + if (ret < 0) { + udp_tunnel_sock_release(new4); + if (ret == -EADDRINUSE && !port && retries++ < 100) + goto retry; + pr_err("%s: Could not create IPv6 socket\n", wg->dev->name); + return ret; + } + set_sock_opts(new6); + setup_udp_tunnel_sock(wg->creating_net, new6, &cfg); } - set_sock_opts(new6); - setup_udp_tunnel_sock(wg->creating_net, new6, &cfg); - rcu_assign_pointer(wg->sock6, new6->sk); #endif -out: - mutex_unlock(&wg->socket_update_lock); - return ret; + socket_reinit(wg, new4 ? new4->sk : NULL, new6 ? new6->sk : NULL); + return 0; } -void socket_uninit(struct wireguard_device *wg) +void socket_reinit(struct wireguard_device *wg, struct sock *new4, struct sock *new6) { struct sock *old4, *old6; mutex_lock(&wg->socket_update_lock); old4 = rcu_dereference_protected(wg->sock4, lockdep_is_held(&wg->socket_update_lock)); old6 = rcu_dereference_protected(wg->sock6, lockdep_is_held(&wg->socket_update_lock)); - rcu_assign_pointer(wg->sock4, NULL); - rcu_assign_pointer(wg->sock6, NULL); + rcu_assign_pointer(wg->sock4, new4); + rcu_assign_pointer(wg->sock6, new6); + if (new4) + wg->incoming_port = ntohs(inet_sk(new4)->inet_sport); mutex_unlock(&wg->socket_update_lock); synchronize_rcu_bh(); synchronize_net(); diff --git a/src/socket.h b/src/socket.h index 843d544..161bd0b 100644 --- a/src/socket.h +++ b/src/socket.h @@ -8,8 +8,8 @@ #include <linux/if_vlan.h> #include <linux/if_ether.h> -int socket_init(struct wireguard_device *wg); -void socket_uninit(struct wireguard_device *wg); +int socket_init(struct wireguard_device *wg, u16 port); +void socket_reinit(struct wireguard_device *wg, struct sock *new4, struct sock *new6); int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *data, size_t len, u8 ds); int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, u8 ds); int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, struct sk_buff *in_skb, void *out_buffer, size_t len); |