summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--src/config.c2
-rw-r--r--src/device.c1
-rw-r--r--src/peer.c8
-rw-r--r--src/peer.h7
-rw-r--r--src/socket.c151
-rw-r--r--src/socket.h1
6 files changed, 68 insertions, 102 deletions
diff --git a/src/config.c b/src/config.c
index 7f217ff..a98ecd9 100644
--- a/src/config.c
+++ b/src/config.c
@@ -11,7 +11,7 @@
static int set_peer_dst(struct wireguard_peer *peer, void *data)
{
- socket_set_peer_dst(peer);
+ dst_cache_reset(&peer->endpoint_cache);
return 0;
}
diff --git a/src/device.c b/src/device.c
index dad5521..41df285 100644
--- a/src/device.c
+++ b/src/device.c
@@ -37,7 +37,6 @@ static void uninit(struct net_device *dev)
static int open_peer(struct wireguard_peer *peer, void *data)
{
- socket_set_peer_dst(peer);
timers_init_peer(peer);
packet_send_queue(peer);
if (peer->persistent_keepalive_interval)
diff --git a/src/peer.c b/src/peer.c
index e1aaf6b..307e499 100644
--- a/src/peer.c
+++ b/src/peer.c
@@ -26,6 +26,11 @@ struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_
if (!peer)
return NULL;
+ if (dst_cache_init(&peer->endpoint_cache, GFP_KERNEL)) {
+ kfree(peer);
+ return NULL;
+ }
+
peer->internal_id = atomic64_inc_return(&peer_counter);
peer->device = wg;
cookie_init(&peer->latest_cookie);
@@ -82,8 +87,7 @@ static void rcu_release(struct rcu_head *rcu)
struct wireguard_peer *peer = container_of(rcu, struct wireguard_peer, rcu);
pr_debug("Peer %Lu (%pISpfsc) destroyed\n", peer->internal_id, &peer->endpoint_addr);
skb_queue_purge(&peer->tx_packet_queue);
- if (peer->endpoint_dst)
- dst_release(peer->endpoint_dst);
+ dst_cache_destroy(&peer->endpoint_cache);
kzfree(peer);
}
diff --git a/src/peer.h b/src/peer.h
index 994b523..c97d934 100644
--- a/src/peer.h
+++ b/src/peer.h
@@ -10,17 +10,14 @@
#include <linux/netfilter.h>
#include <linux/spinlock.h>
#include <linux/kref.h>
+#include <net/dst_cache.h>
struct wireguard_device;
struct wireguard_peer {
struct wireguard_device *device;
struct sockaddr_storage endpoint_addr;
- struct dst_entry *endpoint_dst;
- union {
- struct flowi4 fl4;
- struct flowi6 fl6;
- } endpoint_flow;
+ struct dst_cache endpoint_cache;
rwlock_t endpoint_lock;
struct noise_handshake handshake;
struct noise_keypairs keypairs;
diff --git a/src/socket.c b/src/socket.c
index fc8863d..265c6dc 100644
--- a/src/socket.c
+++ b/src/socket.c
@@ -10,9 +10,16 @@
#include <linux/net.h>
#include <linux/if_vlan.h>
#include <linux/if_ether.h>
+#include <net/dst_cache.h>
#include <net/udp_tunnel.h>
#include <net/ipv6.h>
+
+union flowi46 {
+ struct flowi4 fl4;
+ struct flowi6 fl6;
+};
+
int socket_addr_from_skb(struct sockaddr_storage *sockaddr, struct sk_buff *skb)
{
struct iphdr *ip4;
@@ -41,10 +48,8 @@ int socket_addr_from_skb(struct sockaddr_storage *sockaddr, struct sk_buff *skb)
return 0;
}
-static inline struct dst_entry *route(struct wireguard_device *wg, struct flowi4 *fl4, struct flowi6 *fl6, struct sockaddr_storage *addr, struct sock *sock4, struct sock *sock6)
+static inline struct dst_entry *route(struct wireguard_device *wg, union flowi46 *fl, struct sockaddr_storage *addr, struct sock *sock4, struct sock *sock6, struct dst_cache *cache)
{
- struct dst_entry *dst = ERR_PTR(-EAFNOSUPPORT);
-
if (addr->ss_family == AF_INET) {
struct rtable *rt;
struct sockaddr_in *sin4 = (struct sockaddr_in *)addr;
@@ -52,43 +57,55 @@ static inline struct dst_entry *route(struct wireguard_device *wg, struct flowi4
if (unlikely(!sock4))
return ERR_PTR(-ENONET);
- memset(fl4, 0, sizeof(struct flowi4));
- fl4->daddr = sin4->sin_addr.s_addr;
- fl4->fl4_dport = sin4->sin_port;
- fl4->fl4_sport = htons(wg->incoming_port);
- fl4->flowi4_proto = IPPROTO_UDP;
+ memset(&fl->fl4, 0, sizeof(struct flowi4));
+ fl->fl4.daddr = sin4->sin_addr.s_addr;
+ fl->fl4.fl4_dport = sin4->sin_port;
+ fl->fl4.fl4_sport = htons(wg->incoming_port);
+ fl->fl4.flowi4_proto = IPPROTO_UDP;
- security_sk_classify_flow(sock4, flowi4_to_flowi(fl4));
- rt = ip_route_output_flow(sock_net(sock4), fl4, sock4);
+ rt = dst_cache_get_ip4(cache, &fl->fl4.saddr);
+ if (rt)
+ return &rt->dst;
+
+ security_sk_classify_flow(sock4, flowi4_to_flowi(&fl->fl4));
+ rt = ip_route_output_flow(sock_net(sock4), &fl->fl4, sock4);
if (unlikely(IS_ERR(rt)))
- dst = ERR_PTR(PTR_ERR(rt));
- dst = &rt->dst;
+ return ERR_PTR(PTR_ERR(rt));
+ dst_cache_set_ip4(cache, &rt->dst, fl->fl4.saddr);
+ return &rt->dst;
} else if (addr->ss_family == AF_INET6) {
#if IS_ENABLED(CONFIG_IPV6)
int ret;
struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)addr;
+ struct dst_entry *dst;
if (unlikely(!sock6))
return ERR_PTR(-ENONET);
- memset(fl6, 0, sizeof(struct flowi6));
- fl6->daddr = sin6->sin6_addr;
- fl6->fl6_dport = sin6->sin6_port;
- fl6->fl6_sport = htons(wg->incoming_port);
- fl6->flowi6_oif = sin6->sin6_scope_id;
- fl6->flowi6_proto = IPPROTO_UDP;
+ memset(&fl->fl6, 0, sizeof(struct flowi6));
+ fl->fl6.daddr = sin6->sin6_addr;
+ fl->fl6.fl6_dport = sin6->sin6_port;
+ fl->fl6.fl6_sport = htons(wg->incoming_port);
+ fl->fl6.flowi6_oif = sin6->sin6_scope_id;
+ fl->fl6.flowi6_proto = IPPROTO_UDP;
/* TODO: addr6->sin6_flowinfo */
- security_sk_classify_flow(sock6, flowi6_to_flowi(fl6));
- ret = ipv6_stub->ipv6_dst_lookup(sock_net(sock6), sock6, &dst, fl6);
+ dst = dst_cache_get_ip6(cache, &fl->fl6.saddr);
+ if (dst)
+ return dst;
+
+ security_sk_classify_flow(sock6, flowi6_to_flowi(&fl->fl6));
+ ret = ipv6_stub->ipv6_dst_lookup(sock_net(sock6), sock6, &dst, &fl->fl6);
if (unlikely(ret))
- dst = ERR_PTR(ret);
+ return ERR_PTR(ret);
+ dst_cache_set_ip6(cache, dst, &fl->fl6.saddr);
+ return dst;
#endif
}
- return dst;
+ return ERR_PTR(-EAFNOSUPPORT);
}
-static inline int send(struct net_device *dev, struct sk_buff *skb, struct dst_entry *dst, struct flowi4 *fl4, struct flowi6 *fl6, struct sockaddr_storage *addr, struct sock *sock4, struct sock *sock6, u8 dscp)
+static inline int send(struct net_device *dev, struct sk_buff *skb, struct dst_entry *dst, union flowi46 *fl, struct sockaddr_storage *addr, struct sock *sock4, struct sock *sock6, u8 dscp)
{
int ret = -EAFNOSUPPORT;
@@ -101,9 +118,9 @@ static inline int send(struct net_device *dev, struct sk_buff *skb, struct dst_e
goto err;
}
udp_tunnel_xmit_skb((struct rtable *)dst, sock4, skb,
- fl4->saddr, fl4->daddr,
+ fl->fl4.saddr, fl->fl4.daddr,
dscp, ip4_dst_hoplimit(dst), 0,
- fl4->fl4_sport, fl4->fl4_dport,
+ fl->fl4.fl4_sport, fl->fl4.fl4_dport,
false, false);
return 0;
} else if (addr->ss_family == AF_INET6) {
@@ -113,9 +130,9 @@ static inline int send(struct net_device *dev, struct sk_buff *skb, struct dst_e
}
#if IS_ENABLED(CONFIG_IPV6)
udp_tunnel6_xmit_skb(dst, sock6, skb, dev,
- &fl6->saddr, &fl6->daddr,
+ &fl->fl6.saddr, &fl->fl6.daddr,
dscp, ip6_dst_hoplimit(dst), 0,
- fl6->fl6_sport, fl6->fl6_dport,
+ fl->fl6.fl6_sport, fl->fl6.fl6_dport,
false);
return 0;
#else
@@ -129,34 +146,6 @@ err:
return ret;
}
-static inline void __socket_set_peer_dst(struct wireguard_peer *peer)
-{
- struct dst_entry *old_dst, *new_dst;
- lockdep_assert_held(&peer->endpoint_lock);
-
- old_dst = peer->endpoint_dst;
- peer->endpoint_dst = NULL;
- wmb();
- if (old_dst)
- dst_release(old_dst);
-
- rcu_read_lock();
- new_dst = route(peer->device, &peer->endpoint_flow.fl4, &peer->endpoint_flow.fl6, &peer->endpoint_addr, rcu_dereference(peer->device->sock4), rcu_dereference(peer->device->sock6));
- rcu_read_unlock();
-
- if (likely(!IS_ERR(new_dst))) {
- peer->endpoint_dst = new_dst;
- wmb();
- }
-}
-
-void socket_set_peer_dst(struct wireguard_peer *peer)
-{
- write_lock_bh(&peer->endpoint_lock);
- __socket_set_peer_dst(peer);
- write_unlock_bh(&peer->endpoint_lock);
-}
-
void socket_set_peer_addr(struct wireguard_peer *peer, struct sockaddr_storage *sockaddr)
{
if (sockaddr->ss_family == AF_INET) {
@@ -175,61 +164,42 @@ void socket_set_peer_addr(struct wireguard_peer *peer, struct sockaddr_storage *
memcpy(&peer->endpoint_addr, sockaddr, sizeof(struct sockaddr_in6));
} else
return;
- __socket_set_peer_dst(peer);
+ dst_cache_reset(&peer->endpoint_cache);
write_unlock_bh(&peer->endpoint_lock);
return;
out:
read_unlock_bh(&peer->endpoint_lock);
}
-static inline struct dst_entry *peer_dst_get(struct wireguard_peer *peer)
-{
- struct dst_entry *dst = NULL;
- read_lock_bh(&peer->endpoint_lock);
-
- if (!peer->endpoint_dst || (peer->endpoint_dst->obsolete && !peer->endpoint_dst->ops->check(peer->endpoint_dst, 0))) {
- read_unlock_bh(&peer->endpoint_lock);
- socket_set_peer_dst(peer);
- read_lock_bh(&peer->endpoint_lock);
- if (!peer->endpoint_dst)
- goto out;
- }
-
- if (!atomic_inc_not_zero(&peer->endpoint_dst->__refcnt))
- goto out;
- dst = peer->endpoint_dst;
-
-out:
- read_unlock_bh(&peer->endpoint_lock);
- return dst;
-}
-
-
int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, u8 ds)
{
struct net_device *dev = netdev_pub(peer->device);
struct dst_entry *dst;
+ union flowi46 fl;
size_t skb_len = skb->len;
int ret = 0;
- dst = peer_dst_get(peer);
+ rcu_read_lock();
+ read_lock_bh(&peer->endpoint_lock);
+
+ dst = route(peer->device, &fl, &peer->endpoint_addr, rcu_dereference(peer->device->sock4), rcu_dereference(peer->device->sock6), &peer->endpoint_cache);
if (unlikely(!dst)) {
net_dbg_ratelimited("No route to %pISpfsc for peer %Lu\n", &peer->endpoint_addr, peer->internal_id);
kfree_skb(skb);
- return -EHOSTUNREACH;
+ ret = -EHOSTUNREACH;
+ goto out;
} else if (unlikely(dst->dev == dev)) {
net_dbg_ratelimited("Avoiding routing loop to %pISpfsc for peer %Lu\n", &peer->endpoint_addr, peer->internal_id);
kfree_skb(skb);
- return -ELOOP;
+ ret = -ELOOP;
+ goto out;
}
- rcu_read_lock();
- read_lock_bh(&peer->endpoint_lock);
-
- ret = send(dev, skb, dst, &peer->endpoint_flow.fl4, &peer->endpoint_flow.fl6, &peer->endpoint_addr, rcu_dereference(peer->device->sock4), rcu_dereference(peer->device->sock6), ds);
+ ret = send(dev, skb, dst, &fl, &peer->endpoint_addr, rcu_dereference(peer->device->sock4), rcu_dereference(peer->device->sock6), ds);
if (!ret)
peer->tx_bytes += skb_len;
+out:
read_unlock_bh(&peer->endpoint_lock);
rcu_read_unlock();
@@ -250,12 +220,9 @@ static int send_to_sockaddr(struct sk_buff *skb, struct wireguard_device *wg, st
{
struct dst_entry *dst;
struct net_device *dev = netdev_pub(wg);
- union {
- struct flowi4 fl4;
- struct flowi6 fl6;
- } fl;
+ union flowi46 fl;
- dst = route(wg, &fl.fl4, &fl.fl6, addr, sock4, sock6);
+ dst = route(wg, &fl, addr, sock4, sock6, NULL);
if (IS_ERR(dst)) {
net_dbg_ratelimited("No route to %pISpfsc\n", addr);
kfree_skb(skb);
@@ -267,7 +234,7 @@ static int send_to_sockaddr(struct sk_buff *skb, struct wireguard_device *wg, st
return -ELOOP;
}
- return send(dev, skb, dst, &fl.fl4, &fl.fl6, addr, sock4, sock6, 0);
+ return send(dev, skb, dst, &fl, addr, sock4, sock6, 0);
}
int socket_send_buffer_as_reply_to_skb(struct sk_buff *in_skb, void *out_buffer, size_t len, struct wireguard_device *wg)
diff --git a/src/socket.h b/src/socket.h
index 2fe1bbf..5ab1365 100644
--- a/src/socket.h
+++ b/src/socket.h
@@ -20,6 +20,5 @@ int socket_send_buffer_as_reply_to_skb(struct sk_buff *in_skb, void *out_buffer,
int socket_addr_from_skb(struct sockaddr_storage *sockaddr, struct sk_buff *skb);
void socket_set_peer_addr(struct wireguard_peer *peer, struct sockaddr_storage *sockaddr);
-void socket_set_peer_dst(struct wireguard_peer *peer);
#endif