diff --git a/include/linux/util_macros.h b/include/linux/util_macros.h index 72299f261b253392ead9e08977d033f43bf6b77c..43db6e47503c725b69c2c148fe1f190c0c26aa9d 100644 --- a/include/linux/util_macros.h +++ b/include/linux/util_macros.h @@ -38,4 +38,16 @@ */ #define find_closest_descending(x, a, as) __find_closest(x, a, as, >=) +/** + * is_insidevar - check if the @ptr points inside the @var memory range. + * @ptr: the pointer to a memory address. + * @var: the variable which address and size identify the memory range. + * + * Evaluates to true if the address in @ptr lies within the memory + * range allocated to @var. + */ +#define is_insidevar(ptr, var) \ + ((uintptr_t)(ptr) >= (uintptr_t)(var) && \ + (uintptr_t)(ptr) < (uintptr_t)(var) + sizeof(var)) + #endif diff --git a/net/core/sock_map.c b/net/core/sock_map.c index 9cee55321ef2b1c053d47e63e8505e3c93dc632d..83e3ef655684c1b84ba0ab1024a6dd371e9bf4a6 100644 --- a/net/core/sock_map.c +++ b/net/core/sock_map.c @@ -1561,15 +1561,16 @@ void sock_map_unhash(struct sock *sk) psock = sk_psock(sk); if (unlikely(!psock)) { rcu_read_unlock(); - if (sk->sk_prot->unhash) - sk->sk_prot->unhash(sk); - return; + saved_unhash = READ_ONCE(sk->sk_prot)->unhash; + } else { + saved_unhash = psock->saved_unhash; + sock_map_remove_links(sk, psock); + rcu_read_unlock(); } - - saved_unhash = psock->saved_unhash; - sock_map_remove_links(sk, psock); - rcu_read_unlock(); - saved_unhash(sk); + if (WARN_ON_ONCE(saved_unhash == sock_map_unhash)) + return; + if (saved_unhash) + saved_unhash(sk); } void sock_map_destroy(struct sock *sk) @@ -1581,17 +1582,18 @@ void sock_map_destroy(struct sock *sk) psock = sk_psock_get(sk); if (unlikely(!psock)) { rcu_read_unlock(); - if (sk->sk_prot->destroy) - sk->sk_prot->destroy(sk); - return; + saved_destroy = READ_ONCE(sk->sk_prot)->destroy; + } else { + saved_destroy = psock->saved_destroy; + sock_map_remove_links(sk, psock); + rcu_read_unlock(); + sk_psock_stop(psock, false); + sk_psock_put(sk, psock); } - - saved_destroy = psock->saved_destroy; - sock_map_remove_links(sk, psock); - rcu_read_unlock(); - sk_psock_stop(psock, false); - sk_psock_put(sk, psock); - saved_destroy(sk); + if (WARN_ON_ONCE(saved_destroy == sock_map_destroy)) + return; + if (saved_destroy) + saved_destroy(sk); } EXPORT_SYMBOL_GPL(sock_map_destroy); @@ -1606,15 +1608,20 @@ void sock_map_close(struct sock *sk, long timeout) if (unlikely(!psock)) { rcu_read_unlock(); release_sock(sk); - return sk->sk_prot->close(sk, timeout); + saved_close = READ_ONCE(sk->sk_prot)->close; + } else { + saved_close = psock->saved_close; + sock_map_remove_links(sk, psock); + rcu_read_unlock(); + sk_psock_stop(psock, true); + sk_psock_put(sk, psock); + release_sock(sk); } - - saved_close = psock->saved_close; - sock_map_remove_links(sk, psock); - rcu_read_unlock(); - sk_psock_stop(psock, true); - sk_psock_put(sk, psock); - release_sock(sk); + /* Make sure we do not recurse. This is a bug. + * Leak the socket instead of crashing on a stack overflow. + */ + if (WARN_ON_ONCE(saved_close == sock_map_close)) + return; saved_close(sk, timeout); } diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c index 438992e28efefce3bc32eb9903df2b7f67a131bc..cd214291290c5b1c4a403d723da04685d316085d 100644 --- a/net/ipv4/tcp_bpf.c +++ b/net/ipv4/tcp_bpf.c @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -719,10 +720,9 @@ struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock) */ void tcp_bpf_clone(const struct sock *sk, struct sock *newsk) { - int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4; struct proto *prot = newsk->sk_prot; - if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE]) + if (is_insidevar(prot, tcp_bpf_prots)) newsk->sk_prot = sk->sk_prot_creator; } #endif /* CONFIG_BPF_STREAM_PARSER */