summaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/device.c6
-rw-r--r--src/netlink.c8
-rw-r--r--src/socket.c64
-rw-r--r--src/socket.h4
4 files changed, 35 insertions, 47 deletions
diff --git a/src/device.c b/src/device.c
index ffc36b4..12c11fb 100644
--- a/src/device.c
+++ b/src/device.c
@@ -51,7 +51,7 @@ static int open(struct net_device *dev)
#endif
#endif
- ret = socket_init(wg);
+ ret = socket_init(wg, wg->incoming_port);
if (ret < 0)
return ret;
mutex_lock(&wg->device_update_lock);
@@ -105,7 +105,7 @@ static int stop(struct net_device *dev)
}
mutex_unlock(&wg->device_update_lock);
skb_queue_purge(&wg->incoming_handshakes);
- socket_uninit(wg);
+ socket_reinit(wg, NULL, NULL);
return 0;
}
@@ -220,7 +220,7 @@ static void destruct(struct net_device *dev)
ratelimiter_uninit();
memzero_explicit(&wg->static_identity, sizeof(struct noise_static_identity));
skb_queue_purge(&wg->incoming_handshakes);
- socket_uninit(wg);
+ socket_reinit(wg, NULL, NULL);
free_percpu(dev->tstats);
free_percpu(wg->incoming_handshakes_worker);
if (wg->have_creating_net_ref)
diff --git a/src/netlink.c b/src/netlink.c
index f2e724a..9297e60 100644
--- a/src/netlink.c
+++ b/src/netlink.c
@@ -260,13 +260,13 @@ static int set_port(struct wireguard_device *wg, u16 port)
if (wg->incoming_port == port)
return 0;
- socket_uninit(wg);
- wg->incoming_port = port;
list_for_each_entry(peer, &wg->peer_list, peer_list)
socket_clear_peer_endpoint_src(peer);
- if (!netif_running(wg->dev))
+ if (!netif_running(wg->dev)) {
+ wg->incoming_port = port;
return 0;
- return socket_init(wg);
+ }
+ return socket_init(wg, port);
}
static int set_allowedip(struct wireguard_peer *peer, struct nlattr **attrs)
diff --git a/src/socket.c b/src/socket.c
index 1ce74cd..5bf5a92 100644
--- a/src/socket.c
+++ b/src/socket.c
@@ -20,7 +20,6 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct
.saddr = endpoint->src4.s_addr,
.daddr = endpoint->addr4.sin_addr.s_addr,
.fl4_dport = endpoint->addr4.sin_port,
- .fl4_sport = htons(wg->incoming_port),
.flowi4_mark = wg->fwmark,
.flowi4_proto = IPPROTO_UDP
};
@@ -34,6 +33,7 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct
rcu_read_lock_bh();
sock = rcu_dereference_bh(wg->sock4);
+ fl.fl4_sport = inet_sk(sock)->inet_sport;
if (unlikely(!sock)) {
ret = -ENONET;
@@ -89,7 +89,6 @@ static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct
.saddr = endpoint->src6,
.daddr = endpoint->addr6.sin6_addr,
.fl6_dport = endpoint->addr6.sin6_port,
- .fl6_sport = htons(wg->incoming_port),
.flowi6_mark = wg->fwmark,
.flowi6_oif = endpoint->addr6.sin6_scope_id,
.flowi6_proto = IPPROTO_UDP
@@ -105,6 +104,7 @@ static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct
rcu_read_lock_bh();
sock = rcu_dereference_bh(wg->sock6);
+ fl.fl6_sport = inet_sk(sock)->inet_sport;
if (unlikely(!sock)) {
ret = -ENONET;
@@ -309,87 +309,75 @@ static inline void set_sock_opts(struct socket *sock)
sk_set_memalloc(sock->sk);
}
-int socket_init(struct wireguard_device *wg)
+int socket_init(struct wireguard_device *wg, u16 port)
{
- int ret = 0;
+ int ret;
struct udp_tunnel_sock_cfg cfg = {
.sk_user_data = wg,
.encap_type = 1,
.encap_rcv = receive
};
- struct socket *new4 = NULL;
+ struct socket *new4 = NULL, *new6 = NULL;
struct udp_port_cfg port4 = {
.family = AF_INET,
.local_ip.s_addr = htonl(INADDR_ANY),
- .local_udp_port = htons(wg->incoming_port),
+ .local_udp_port = htons(port),
.use_udp_checksums = true
};
#if IS_ENABLED(CONFIG_IPV6)
int retries = 0;
- struct socket *new6 = NULL;
struct udp_port_cfg port6 = {
.family = AF_INET6,
.local_ip6 = IN6ADDR_ANY_INIT,
- .local_udp_port = htons(wg->incoming_port),
.use_udp6_tx_checksums = true,
.use_udp6_rx_checksums = true,
.ipv6_v6only = true
};
#endif
- mutex_lock(&wg->socket_update_lock);
#if IS_ENABLED(CONFIG_IPV6)
retry:
#endif
- if (rcu_dereference_protected(wg->sock4, lockdep_is_held(&wg->socket_update_lock)) || rcu_dereference_protected(wg->sock6, lockdep_is_held(&wg->socket_update_lock))) {
- ret = -EADDRINUSE;
- goto out;
- }
ret = udp_sock_create(wg->creating_net, &port4, &new4);
if (ret < 0) {
pr_err("%s: Could not create IPv4 socket\n", wg->dev->name);
- goto out;
+ return ret;
}
- wg->incoming_port = ntohs(inet_sk(new4->sk)->inet_sport);
set_sock_opts(new4);
setup_udp_tunnel_sock(wg->creating_net, new4, &cfg);
- rcu_assign_pointer(wg->sock4, new4->sk);
#if IS_ENABLED(CONFIG_IPV6)
- if (!ipv6_mod_enabled())
- goto out;
- port6.local_udp_port = htons(wg->incoming_port);
- ret = udp_sock_create(wg->creating_net, &port6, &new6);
- if (ret < 0) {
- udp_tunnel_sock_release(new4);
- rcu_assign_pointer(wg->sock4, NULL);
- if (ret == -EADDRINUSE && !port4.local_udp_port && retries++ < 100)
- goto retry;
- if (!port4.local_udp_port)
- wg->incoming_port = 0;
- pr_err("%s: Could not create IPv6 socket\n", wg->dev->name);
- goto out;
+ if (ipv6_mod_enabled()) {
+ port6.local_udp_port = inet_sk(new4->sk)->inet_sport;
+ ret = udp_sock_create(wg->creating_net, &port6, &new6);
+ if (ret < 0) {
+ udp_tunnel_sock_release(new4);
+ if (ret == -EADDRINUSE && !port && retries++ < 100)
+ goto retry;
+ pr_err("%s: Could not create IPv6 socket\n", wg->dev->name);
+ return ret;
+ }
+ set_sock_opts(new6);
+ setup_udp_tunnel_sock(wg->creating_net, new6, &cfg);
}
- set_sock_opts(new6);
- setup_udp_tunnel_sock(wg->creating_net, new6, &cfg);
- rcu_assign_pointer(wg->sock6, new6->sk);
#endif
-out:
- mutex_unlock(&wg->socket_update_lock);
- return ret;
+ socket_reinit(wg, new4 ? new4->sk : NULL, new6 ? new6->sk : NULL);
+ return 0;
}
-void socket_uninit(struct wireguard_device *wg)
+void socket_reinit(struct wireguard_device *wg, struct sock *new4, struct sock *new6)
{
struct sock *old4, *old6;
mutex_lock(&wg->socket_update_lock);
old4 = rcu_dereference_protected(wg->sock4, lockdep_is_held(&wg->socket_update_lock));
old6 = rcu_dereference_protected(wg->sock6, lockdep_is_held(&wg->socket_update_lock));
- rcu_assign_pointer(wg->sock4, NULL);
- rcu_assign_pointer(wg->sock6, NULL);
+ rcu_assign_pointer(wg->sock4, new4);
+ rcu_assign_pointer(wg->sock6, new6);
+ if (new4)
+ wg->incoming_port = ntohs(inet_sk(new4)->inet_sport);
mutex_unlock(&wg->socket_update_lock);
synchronize_rcu_bh();
synchronize_net();
diff --git a/src/socket.h b/src/socket.h
index 843d544..161bd0b 100644
--- a/src/socket.h
+++ b/src/socket.h
@@ -8,8 +8,8 @@
#include <linux/if_vlan.h>
#include <linux/if_ether.h>
-int socket_init(struct wireguard_device *wg);
-void socket_uninit(struct wireguard_device *wg);
+int socket_init(struct wireguard_device *wg, u16 port);
+void socket_reinit(struct wireguard_device *wg, struct sock *new4, struct sock *new6);
int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *data, size_t len, u8 ds);
int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, u8 ds);
int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, struct sk_buff *in_skb, void *out_buffer, size_t len);