diff options
-rw-r--r-- | src/device.c | 2 | ||||
-rw-r--r-- | src/netlink.c | 24 | ||||
-rw-r--r-- | src/socket.c | 22 | ||||
-rw-r--r-- | src/socket.h | 2 |
4 files changed, 25 insertions, 25 deletions
diff --git a/src/device.c b/src/device.c index b8c2390..3bbd0e8 100644 --- a/src/device.c +++ b/src/device.c @@ -54,7 +54,7 @@ static int wg_open(struct net_device *dev) #endif mutex_lock(&wg->device_update_lock); - ret = wg_socket_init(wg, wg->incoming_port); + ret = wg_socket_init(wg, wg->incoming_port, wg->socketdev_index); if (ret < 0) goto out; list_for_each_entry(peer, &wg->peer_list, peer_list) { diff --git a/src/netlink.c b/src/netlink.c index 573df1a..3021424 100644 --- a/src/netlink.c +++ b/src/netlink.c @@ -311,19 +311,20 @@ static int wg_get_device_done(struct netlink_callback *cb) return 0; } -static int set_port(struct wg_device *wg, u16 port) +static int set_port(struct wg_device *wg, u16 port, u32 dev) { struct wg_peer *peer; - if (wg->incoming_port == port) + if (wg->incoming_port == port && wg->socketdev_index == dev) return 0; list_for_each_entry(peer, &wg->peer_list, peer_list) wg_socket_clear_peer_endpoint_src(peer); if (!netif_running(wg->dev)) { wg->incoming_port = port; + wg->socketdev_index = dev; return 0; } - return wg_socket_init(wg, port); + return wg_socket_init(wg, port, dev); } static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs) @@ -531,16 +532,21 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info) wg_socket_clear_peer_endpoint_src(peer); } - if (info->attrs[WGDEVICE_A_LISTEN_PORT]) { - ret = set_port(wg, - nla_get_u16(info->attrs[WGDEVICE_A_LISTEN_PORT])); + if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_SOCKETDEV_INDEX]) { + u16 port = wg->incoming_port; + u32 dev = wg->socketdev_index; + + if (info->attrs[WGDEVICE_A_LISTEN_PORT]) + port = nla_get_u16(info->attrs[WGDEVICE_A_LISTEN_PORT]); + + if (info->attrs[WGDEVICE_A_SOCKETDEV_INDEX]) + dev = nla_get_u32(info->attrs[WGDEVICE_A_SOCKETDEV_INDEX]); + + ret = set_port(wg, port, dev); if (ret) goto out; } - if (info->attrs[WGDEVICE_A_SOCKETDEV_INDEX]) { - wg->socketdev_index = nla_get_u32(info->attrs[WGDEVICE_A_SOCKETDEV_INDEX]); - } if (flags & WGDEVICE_F_REPLACE_PEERS) wg_peer_remove_all(wg); diff --git a/src/socket.c b/src/socket.c index 8ef44c1..75780b0 100644 --- a/src/socket.c +++ b/src/socket.c @@ -345,7 +345,7 @@ static void set_sock_opts(struct socket *sock) sk_set_memalloc(sock->sk); } -int wg_socket_init(struct wg_device *wg, u16 port) +int wg_socket_init(struct wg_device *wg, u16 port, u32 socketdev) { struct net *net; int ret; @@ -359,7 +359,8 @@ int wg_socket_init(struct wg_device *wg, u16 port) .family = AF_INET, .local_ip.s_addr = htonl(INADDR_ANY), .local_udp_port = htons(port), - .use_udp_checksums = true + .use_udp_checksums = true, + .bind_ifindex = socketdev, }; #if IS_ENABLED(CONFIG_IPV6) int retries = 0; @@ -368,20 +369,11 @@ int wg_socket_init(struct wg_device *wg, u16 port) .local_ip6 = IN6ADDR_ANY_INIT, .use_udp6_tx_checksums = true, .use_udp6_rx_checksums = true, - .ipv6_v6only = true + .ipv6_v6only = true, + .bind_ifindex = socketdev, }; - if (wg->socketdev_index > 0) { - port6.bind_ifindex = wg->socketdev_index; - } else { - port6.bind_ifindex = 0; - } #endif - if (wg->socketdev_index > 0) { - port4.bind_ifindex = wg->socketdev_index; - } else { - port4.bind_ifindex = 0; - } rcu_read_lock(); net = rcu_dereference(wg->creating_net); net = net ? maybe_get_net(net) : NULL; @@ -437,8 +429,10 @@ void wg_socket_reinit(struct wg_device *wg, struct sock *new4, lockdep_is_held(&wg->socket_update_lock)); rcu_assign_pointer(wg->sock4, new4); rcu_assign_pointer(wg->sock6, new6); - if (new4) + if (new4) { wg->incoming_port = ntohs(inet_sk(new4)->inet_sport); + wg->socketdev_index = new4->sk_bound_dev_if; + } mutex_unlock(&wg->socket_update_lock); synchronize_net(); sock_free(old4); diff --git a/src/socket.h b/src/socket.h index bab5848..f18308d 100644 --- a/src/socket.h +++ b/src/socket.h @@ -11,7 +11,7 @@ #include <linux/if_vlan.h> #include <linux/if_ether.h> -int wg_socket_init(struct wg_device *wg, u16 port); +int wg_socket_init(struct wg_device *wg, u16 port, u32 socketdev); void wg_socket_reinit(struct wg_device *wg, struct sock *new4, struct sock *new6); int wg_socket_send_buffer_to_peer(struct wg_peer *peer, void *data, |