diff options
Diffstat (limited to 'src/socket.c')
-rw-r--r-- | src/socket.c | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/src/socket.c b/src/socket.c index cdb0f47..1dff57a 100644 --- a/src/socket.c +++ b/src/socket.c @@ -44,9 +44,14 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct if (!rt) { security_sk_classify_flow(sock, flowi4_to_flowi(&fl)); + if (unlikely(!inet_confirm_addr(sock_net(sock), NULL, 0, fl.saddr, RT_SCOPE_HOST))) { + endpoint->src4.s_addr = endpoint->src_if4 = fl.saddr = 0; + if (cache) + dst_cache_reset(cache); + } rt = ip_route_output_flow(sock_net(sock), &fl, sock); - if (unlikely(endpoint->src4.s_addr && ((IS_ERR(rt) && PTR_ERR(rt) == -EINVAL) || (!IS_ERR(rt) && !inet_confirm_addr(sock_net(sock), rcu_dereference_bh(rt->dst.dev->ip_ptr), 0, fl.saddr, RT_SCOPE_HOST))))) { - endpoint->src4.s_addr = fl.saddr = 0; + if (unlikely(endpoint->src_if4 && ((IS_ERR(rt) && PTR_ERR(rt) == -EINVAL) || (!IS_ERR(rt) && rt->dst.dev->ifindex != endpoint->src_if4)))) { + endpoint->src4.s_addr = endpoint->src_if4 = fl.saddr = 0; if (cache) dst_cache_reset(cache); if (!IS_ERR(rt)) @@ -204,6 +209,7 @@ int socket_endpoint_from_skb(struct endpoint *endpoint, struct sk_buff *skb) endpoint->addr4.sin_port = udp_hdr(skb)->source; endpoint->addr4.sin_addr.s_addr = ip_hdr(skb)->saddr; endpoint->src4.s_addr = ip_hdr(skb)->daddr; + endpoint->src_if4 = skb->skb_iif; } else if (skb->protocol == htons(ETH_P_IPV6)) { endpoint->addr6.sin6_family = AF_INET6; endpoint->addr6.sin6_port = udp_hdr(skb)->source; @@ -223,12 +229,14 @@ void socket_set_peer_endpoint(struct wireguard_peer *peer, struct endpoint *endp if (likely(peer->endpoint.addr4.sin_family == AF_INET && peer->endpoint.addr4.sin_port == endpoint->addr4.sin_port && peer->endpoint.addr4.sin_addr.s_addr == endpoint->addr4.sin_addr.s_addr && - peer->endpoint.src4.s_addr == endpoint->src4.s_addr)) + peer->endpoint.src4.s_addr == endpoint->src4.s_addr && + peer->endpoint.src_if4 == endpoint->src_if4)) goto out; read_unlock_bh(&peer->endpoint_lock); write_lock_bh(&peer->endpoint_lock); peer->endpoint.addr4 = endpoint->addr4; peer->endpoint.src4 = endpoint->src4; + peer->endpoint.src_if4 = endpoint->src_if4; } else if (endpoint->addr.sa_family == AF_INET6) { read_lock_bh(&peer->endpoint_lock); if (likely(peer->endpoint.addr6.sin6_family == AF_INET6 && |