diff options
-rw-r--r-- | src/config.c | 2 | ||||
-rw-r--r-- | src/device.c | 1 | ||||
-rw-r--r-- | src/peer.c | 8 | ||||
-rw-r--r-- | src/peer.h | 7 | ||||
-rw-r--r-- | src/socket.c | 151 | ||||
-rw-r--r-- | src/socket.h | 1 |
6 files changed, 68 insertions, 102 deletions
diff --git a/src/config.c b/src/config.c index 7f217ff..a98ecd9 100644 --- a/src/config.c +++ b/src/config.c @@ -11,7 +11,7 @@ static int set_peer_dst(struct wireguard_peer *peer, void *data) { - socket_set_peer_dst(peer); + dst_cache_reset(&peer->endpoint_cache); return 0; } diff --git a/src/device.c b/src/device.c index dad5521..41df285 100644 --- a/src/device.c +++ b/src/device.c @@ -37,7 +37,6 @@ static void uninit(struct net_device *dev) static int open_peer(struct wireguard_peer *peer, void *data) { - socket_set_peer_dst(peer); timers_init_peer(peer); packet_send_queue(peer); if (peer->persistent_keepalive_interval) @@ -26,6 +26,11 @@ struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_ if (!peer) return NULL; + if (dst_cache_init(&peer->endpoint_cache, GFP_KERNEL)) { + kfree(peer); + return NULL; + } + peer->internal_id = atomic64_inc_return(&peer_counter); peer->device = wg; cookie_init(&peer->latest_cookie); @@ -82,8 +87,7 @@ static void rcu_release(struct rcu_head *rcu) struct wireguard_peer *peer = container_of(rcu, struct wireguard_peer, rcu); pr_debug("Peer %Lu (%pISpfsc) destroyed\n", peer->internal_id, &peer->endpoint_addr); skb_queue_purge(&peer->tx_packet_queue); - if (peer->endpoint_dst) - dst_release(peer->endpoint_dst); + dst_cache_destroy(&peer->endpoint_cache); kzfree(peer); } @@ -10,17 +10,14 @@ #include <linux/netfilter.h> #include <linux/spinlock.h> #include <linux/kref.h> +#include <net/dst_cache.h> struct wireguard_device; struct wireguard_peer { struct wireguard_device *device; struct sockaddr_storage endpoint_addr; - struct dst_entry *endpoint_dst; - union { - struct flowi4 fl4; - struct flowi6 fl6; - } endpoint_flow; + struct dst_cache endpoint_cache; rwlock_t endpoint_lock; struct noise_handshake handshake; struct noise_keypairs keypairs; diff --git a/src/socket.c b/src/socket.c index fc8863d..265c6dc 100644 --- a/src/socket.c +++ b/src/socket.c @@ -10,9 +10,16 @@ #include <linux/net.h> #include <linux/if_vlan.h> #include <linux/if_ether.h> +#include <net/dst_cache.h> #include <net/udp_tunnel.h> #include <net/ipv6.h> + +union flowi46 { + struct flowi4 fl4; + struct flowi6 fl6; +}; + int socket_addr_from_skb(struct sockaddr_storage *sockaddr, struct sk_buff *skb) { struct iphdr *ip4; @@ -41,10 +48,8 @@ int socket_addr_from_skb(struct sockaddr_storage *sockaddr, struct sk_buff *skb) return 0; } -static inline struct dst_entry *route(struct wireguard_device *wg, struct flowi4 *fl4, struct flowi6 *fl6, struct sockaddr_storage *addr, struct sock *sock4, struct sock *sock6) +static inline struct dst_entry *route(struct wireguard_device *wg, union flowi46 *fl, struct sockaddr_storage *addr, struct sock *sock4, struct sock *sock6, struct dst_cache *cache) { - struct dst_entry *dst = ERR_PTR(-EAFNOSUPPORT); - if (addr->ss_family == AF_INET) { struct rtable *rt; struct sockaddr_in *sin4 = (struct sockaddr_in *)addr; @@ -52,43 +57,55 @@ static inline struct dst_entry *route(struct wireguard_device *wg, struct flowi4 if (unlikely(!sock4)) return ERR_PTR(-ENONET); - memset(fl4, 0, sizeof(struct flowi4)); - fl4->daddr = sin4->sin_addr.s_addr; - fl4->fl4_dport = sin4->sin_port; - fl4->fl4_sport = htons(wg->incoming_port); - fl4->flowi4_proto = IPPROTO_UDP; + memset(&fl->fl4, 0, sizeof(struct flowi4)); + fl->fl4.daddr = sin4->sin_addr.s_addr; + fl->fl4.fl4_dport = sin4->sin_port; + fl->fl4.fl4_sport = htons(wg->incoming_port); + fl->fl4.flowi4_proto = IPPROTO_UDP; - security_sk_classify_flow(sock4, flowi4_to_flowi(fl4)); - rt = ip_route_output_flow(sock_net(sock4), fl4, sock4); + rt = dst_cache_get_ip4(cache, &fl->fl4.saddr); + if (rt) + return &rt->dst; + + security_sk_classify_flow(sock4, flowi4_to_flowi(&fl->fl4)); + rt = ip_route_output_flow(sock_net(sock4), &fl->fl4, sock4); if (unlikely(IS_ERR(rt))) - dst = ERR_PTR(PTR_ERR(rt)); - dst = &rt->dst; + return ERR_PTR(PTR_ERR(rt)); + dst_cache_set_ip4(cache, &rt->dst, fl->fl4.saddr); + return &rt->dst; } else if (addr->ss_family == AF_INET6) { #if IS_ENABLED(CONFIG_IPV6) int ret; struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)addr; + struct dst_entry *dst; if (unlikely(!sock6)) return ERR_PTR(-ENONET); - memset(fl6, 0, sizeof(struct flowi6)); - fl6->daddr = sin6->sin6_addr; - fl6->fl6_dport = sin6->sin6_port; - fl6->fl6_sport = htons(wg->incoming_port); - fl6->flowi6_oif = sin6->sin6_scope_id; - fl6->flowi6_proto = IPPROTO_UDP; + memset(&fl->fl6, 0, sizeof(struct flowi6)); + fl->fl6.daddr = sin6->sin6_addr; + fl->fl6.fl6_dport = sin6->sin6_port; + fl->fl6.fl6_sport = htons(wg->incoming_port); + fl->fl6.flowi6_oif = sin6->sin6_scope_id; + fl->fl6.flowi6_proto = IPPROTO_UDP; /* TODO: addr6->sin6_flowinfo */ - security_sk_classify_flow(sock6, flowi6_to_flowi(fl6)); - ret = ipv6_stub->ipv6_dst_lookup(sock_net(sock6), sock6, &dst, fl6); + dst = dst_cache_get_ip6(cache, &fl->fl6.saddr); + if (dst) + return dst; + + security_sk_classify_flow(sock6, flowi6_to_flowi(&fl->fl6)); + ret = ipv6_stub->ipv6_dst_lookup(sock_net(sock6), sock6, &dst, &fl->fl6); if (unlikely(ret)) - dst = ERR_PTR(ret); + return ERR_PTR(ret); + dst_cache_set_ip6(cache, dst, &fl->fl6.saddr); + return dst; #endif } - return dst; + return ERR_PTR(-EAFNOSUPPORT); } -static inline int send(struct net_device *dev, struct sk_buff *skb, struct dst_entry *dst, struct flowi4 *fl4, struct flowi6 *fl6, struct sockaddr_storage *addr, struct sock *sock4, struct sock *sock6, u8 dscp) +static inline int send(struct net_device *dev, struct sk_buff *skb, struct dst_entry *dst, union flowi46 *fl, struct sockaddr_storage *addr, struct sock *sock4, struct sock *sock6, u8 dscp) { int ret = -EAFNOSUPPORT; @@ -101,9 +118,9 @@ static inline int send(struct net_device *dev, struct sk_buff *skb, struct dst_e goto err; } udp_tunnel_xmit_skb((struct rtable *)dst, sock4, skb, - fl4->saddr, fl4->daddr, + fl->fl4.saddr, fl->fl4.daddr, dscp, ip4_dst_hoplimit(dst), 0, - fl4->fl4_sport, fl4->fl4_dport, + fl->fl4.fl4_sport, fl->fl4.fl4_dport, false, false); return 0; } else if (addr->ss_family == AF_INET6) { @@ -113,9 +130,9 @@ static inline int send(struct net_device *dev, struct sk_buff *skb, struct dst_e } #if IS_ENABLED(CONFIG_IPV6) udp_tunnel6_xmit_skb(dst, sock6, skb, dev, - &fl6->saddr, &fl6->daddr, + &fl->fl6.saddr, &fl->fl6.daddr, dscp, ip6_dst_hoplimit(dst), 0, - fl6->fl6_sport, fl6->fl6_dport, + fl->fl6.fl6_sport, fl->fl6.fl6_dport, false); return 0; #else @@ -129,34 +146,6 @@ err: return ret; } -static inline void __socket_set_peer_dst(struct wireguard_peer *peer) -{ - struct dst_entry *old_dst, *new_dst; - lockdep_assert_held(&peer->endpoint_lock); - - old_dst = peer->endpoint_dst; - peer->endpoint_dst = NULL; - wmb(); - if (old_dst) - dst_release(old_dst); - - rcu_read_lock(); - new_dst = route(peer->device, &peer->endpoint_flow.fl4, &peer->endpoint_flow.fl6, &peer->endpoint_addr, rcu_dereference(peer->device->sock4), rcu_dereference(peer->device->sock6)); - rcu_read_unlock(); - - if (likely(!IS_ERR(new_dst))) { - peer->endpoint_dst = new_dst; - wmb(); - } -} - -void socket_set_peer_dst(struct wireguard_peer *peer) -{ - write_lock_bh(&peer->endpoint_lock); - __socket_set_peer_dst(peer); - write_unlock_bh(&peer->endpoint_lock); -} - void socket_set_peer_addr(struct wireguard_peer *peer, struct sockaddr_storage *sockaddr) { if (sockaddr->ss_family == AF_INET) { @@ -175,61 +164,42 @@ void socket_set_peer_addr(struct wireguard_peer *peer, struct sockaddr_storage * memcpy(&peer->endpoint_addr, sockaddr, sizeof(struct sockaddr_in6)); } else return; - __socket_set_peer_dst(peer); + dst_cache_reset(&peer->endpoint_cache); write_unlock_bh(&peer->endpoint_lock); return; out: read_unlock_bh(&peer->endpoint_lock); } -static inline struct dst_entry *peer_dst_get(struct wireguard_peer *peer) -{ - struct dst_entry *dst = NULL; - read_lock_bh(&peer->endpoint_lock); - - if (!peer->endpoint_dst || (peer->endpoint_dst->obsolete && !peer->endpoint_dst->ops->check(peer->endpoint_dst, 0))) { - read_unlock_bh(&peer->endpoint_lock); - socket_set_peer_dst(peer); - read_lock_bh(&peer->endpoint_lock); - if (!peer->endpoint_dst) - goto out; - } - - if (!atomic_inc_not_zero(&peer->endpoint_dst->__refcnt)) - goto out; - dst = peer->endpoint_dst; - -out: - read_unlock_bh(&peer->endpoint_lock); - return dst; -} - - int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, u8 ds) { struct net_device *dev = netdev_pub(peer->device); struct dst_entry *dst; + union flowi46 fl; size_t skb_len = skb->len; int ret = 0; - dst = peer_dst_get(peer); + rcu_read_lock(); + read_lock_bh(&peer->endpoint_lock); + + dst = route(peer->device, &fl, &peer->endpoint_addr, rcu_dereference(peer->device->sock4), rcu_dereference(peer->device->sock6), &peer->endpoint_cache); if (unlikely(!dst)) { net_dbg_ratelimited("No route to %pISpfsc for peer %Lu\n", &peer->endpoint_addr, peer->internal_id); kfree_skb(skb); - return -EHOSTUNREACH; + ret = -EHOSTUNREACH; + goto out; } else if (unlikely(dst->dev == dev)) { net_dbg_ratelimited("Avoiding routing loop to %pISpfsc for peer %Lu\n", &peer->endpoint_addr, peer->internal_id); kfree_skb(skb); - return -ELOOP; + ret = -ELOOP; + goto out; } - rcu_read_lock(); - read_lock_bh(&peer->endpoint_lock); - - ret = send(dev, skb, dst, &peer->endpoint_flow.fl4, &peer->endpoint_flow.fl6, &peer->endpoint_addr, rcu_dereference(peer->device->sock4), rcu_dereference(peer->device->sock6), ds); + ret = send(dev, skb, dst, &fl, &peer->endpoint_addr, rcu_dereference(peer->device->sock4), rcu_dereference(peer->device->sock6), ds); if (!ret) peer->tx_bytes += skb_len; +out: read_unlock_bh(&peer->endpoint_lock); rcu_read_unlock(); @@ -250,12 +220,9 @@ static int send_to_sockaddr(struct sk_buff *skb, struct wireguard_device *wg, st { struct dst_entry *dst; struct net_device *dev = netdev_pub(wg); - union { - struct flowi4 fl4; - struct flowi6 fl6; - } fl; + union flowi46 fl; - dst = route(wg, &fl.fl4, &fl.fl6, addr, sock4, sock6); + dst = route(wg, &fl, addr, sock4, sock6, NULL); if (IS_ERR(dst)) { net_dbg_ratelimited("No route to %pISpfsc\n", addr); kfree_skb(skb); @@ -267,7 +234,7 @@ static int send_to_sockaddr(struct sk_buff *skb, struct wireguard_device *wg, st return -ELOOP; } - return send(dev, skb, dst, &fl.fl4, &fl.fl6, addr, sock4, sock6, 0); + return send(dev, skb, dst, &fl, addr, sock4, sock6, 0); } int socket_send_buffer_as_reply_to_skb(struct sk_buff *in_skb, void *out_buffer, size_t len, struct wireguard_device *wg) diff --git a/src/socket.h b/src/socket.h index 2fe1bbf..5ab1365 100644 --- a/src/socket.h +++ b/src/socket.h @@ -20,6 +20,5 @@ int socket_send_buffer_as_reply_to_skb(struct sk_buff *in_skb, void *out_buffer, int socket_addr_from_skb(struct sockaddr_storage *sockaddr, struct sk_buff *skb); void socket_set_peer_addr(struct wireguard_peer *peer, struct sockaddr_storage *sockaddr); -void socket_set_peer_dst(struct wireguard_peer *peer); #endif |