diff --git a/include/net/net_namespace.h b/include/net/net_namespace.h index 9f1bb06ebb99ab60c506740490281badca6b5625..9f5162b0030f9ec3e405e4304454c1f3cec83e34 100644 --- a/include/net/net_namespace.h +++ b/include/net/net_namespace.h @@ -293,6 +293,7 @@ static inline int check_net(const struct net *net) } void net_drop_ns(void *); +void net_passive_dec(struct net *net); #else @@ -322,8 +323,17 @@ static inline int check_net(const struct net *net) } #define net_drop_ns NULL + +static inline void net_passive_dec(struct net *net) +{ + refcount_dec(&net->passive); +} #endif +static inline void net_passive_inc(struct net *net) +{ + refcount_inc(&net->passive); +} static inline void __netns_tracker_alloc(struct net *net, netns_tracker *tracker, diff --git a/include/net/sock.h b/include/net/sock.h index 330afe640f028c9bef1a082d26532a98476c56be..81050e48698e1f4a9d0bd93e2e2f2688079c975c 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -1886,6 +1886,7 @@ static inline bool sock_allow_reclassification(const struct sock *csk) struct sock *sk_alloc(struct net *net, int family, gfp_t priority, struct proto *prot, int kern); void sk_free(struct sock *sk); +void sk_net_refcnt_upgrade(struct sock *sk); void sk_destruct(struct sock *sk); struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority); void sk_free_unlock_clone(struct sock *sk); diff --git a/net/core/net_namespace.c b/net/core/net_namespace.c index 92b7fea4d495cfc0de9366101020677b188f2fc5..152761d5ca2fe748ae3c3dc0fcdd4e1913a436eb 100644 --- a/net/core/net_namespace.c +++ b/net/core/net_namespace.c @@ -457,7 +457,7 @@ static void net_complete_free(void) } -static void net_free(struct net *net) +void net_passive_dec(struct net *net) { if (refcount_dec_and_test(&net->passive)) { kfree(rcu_access_pointer(net->gen)); @@ -475,7 +475,7 @@ void net_drop_ns(void *p) struct net *net = (struct net *)p; if (net) - net_free(net); + net_passive_dec(net); } struct net *copy_net_ns(unsigned long flags, @@ -517,7 +517,7 @@ struct net *copy_net_ns(unsigned long flags, key_remove_domain(net->key_domain); #endif put_user_ns(user_ns); - net_free(net); + net_passive_dec(net); dec_ucounts: dec_net_namespaces(ucounts); return ERR_PTR(rv); @@ -653,7 +653,7 @@ static void cleanup_net(struct work_struct *work) key_remove_domain(net->key_domain); #endif put_user_ns(net->user_ns); - net_free(net); + net_passive_dec(net); } } diff --git a/net/core/sock.c b/net/core/sock.c index 93ee2bba32f64748d8f6a20a9cc3f063a561b95c..1e27275bb24f297b40bb48dd68ce8c597a4780ed 100644 --- a/net/core/sock.c +++ b/net/core/sock.c @@ -2159,6 +2159,7 @@ struct sock *sk_alloc(struct net *net, int family, gfp_t priority, get_net_track(net, &sk->ns_tracker, priority); sock_inuse_add(net, 1); } else { + net_passive_inc(net); __netns_tracker_alloc(net, &sk->ns_tracker, false, priority); } @@ -2183,6 +2184,7 @@ EXPORT_SYMBOL(sk_alloc); static void __sk_destruct(struct rcu_head *head) { struct sock *sk = container_of(head, struct sock, sk_rcu); + struct net *net = sock_net(sk); struct sk_filter *filter; if (sk->sk_destruct) @@ -2214,14 +2216,28 @@ static void __sk_destruct(struct rcu_head *head) put_cred(sk->sk_peer_cred); put_pid(sk->sk_peer_pid); - if (likely(sk->sk_net_refcnt)) - put_net_track(sock_net(sk), &sk->ns_tracker); - else - __netns_tracker_free(sock_net(sk), &sk->ns_tracker, false); - + if (likely(sk->sk_net_refcnt)) { + put_net_track(net, &sk->ns_tracker); + } else { + __netns_tracker_free(net, &sk->ns_tracker, false); + net_passive_dec(net); + } sk_prot_free(sk->sk_prot_creator, sk); } +void sk_net_refcnt_upgrade(struct sock *sk) +{ + struct net *net = sock_net(sk); + + WARN_ON_ONCE(sk->sk_net_refcnt); + __netns_tracker_free(net, &sk->ns_tracker, false); + net_passive_dec(net); + sk->sk_net_refcnt = 1; + get_net_track(net, &sk->ns_tracker, GFP_KERNEL); + sock_inuse_add(net, 1); +} +EXPORT_SYMBOL_GPL(sk_net_refcnt_upgrade); + void sk_destruct(struct sock *sk) { bool use_call_rcu = sock_flag(sk, SOCK_RCU_FREE); @@ -2313,6 +2329,7 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority) * is not properly dismantling its kernel sockets at netns * destroy time. */ + net_passive_inc(sock_net(newsk)); __netns_tracker_alloc(sock_net(newsk), &newsk->ns_tracker, false, priority); } diff --git a/net/mptcp/subflow.c b/net/mptcp/subflow.c index a7b12181be1f566886ef83f0768119de7183db3e..3f8403077f0c11710730708805c0881d3402f89f 100644 --- a/net/mptcp/subflow.c +++ b/net/mptcp/subflow.c @@ -1726,10 +1726,7 @@ int mptcp_subflow_create_socket(struct sock *sk, unsigned short family, * needs it. * Update ns_tracker to current stack trace and refcounted tracker. */ - __netns_tracker_free(net, &sf->sk->ns_tracker, false); - sf->sk->sk_net_refcnt = 1; - get_net_track(net, &sf->sk->ns_tracker, GFP_KERNEL); - sock_inuse_add(net, 1); + sk_net_refcnt_upgrade(sf->sk); err = tcp_set_ulp(sf->sk, "mptcp"); if (err) goto err_free; diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index 4aa2cbe9d6fa69957fd99025ba72f5932fdc8287..cdb66eb0c3fa52f07a4939f9d2c85a15f214077a 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -799,16 +799,6 @@ static int netlink_release(struct socket *sock) sock_prot_inuse_add(sock_net(sk), &netlink_proto, -1); - /* Because struct net might disappear soon, do not keep a pointer. */ - if (!sk->sk_net_refcnt && sock_net(sk) != &init_net) { - __netns_tracker_free(sock_net(sk), &sk->ns_tracker, false); - /* Because of deferred_put_nlk_sk and use of work queue, - * it is possible netns will be freed before this socket. - */ - sock_net_set(sk, &init_net); - __netns_tracker_alloc(&init_net, &sk->ns_tracker, - false, GFP_KERNEL); - } call_rcu(&nlk->rcu, deferred_put_nlk_sk); return 0; } diff --git a/net/rds/tcp.c b/net/rds/tcp.c index fb6eaea3601480948f5b8a0234e03ff550844ddd..e5d90e4beb35b6e7c639615d290648614e68d60b 100644 --- a/net/rds/tcp.c +++ b/net/rds/tcp.c @@ -505,12 +505,8 @@ bool rds_tcp_tune(struct socket *sock) release_sock(sk); return false; } - /* Update ns_tracker to current stack trace and refcounted tracker */ - __netns_tracker_free(net, &sk->ns_tracker, false); - - sk->sk_net_refcnt = 1; - netns_tracker_alloc(net, &sk->ns_tracker, GFP_KERNEL); - sock_inuse_add(net, 1); + sk_net_refcnt_upgrade(sk); + put_net(net); } rtn = net_generic(net, rds_tcp_netid); if (rtn->sndbuf_size > 0) { diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c index df0654d7f3528596a956874782153ed4aa1827a1..99a815bc42f80e004fe621996209abcf17cfa56d 100644 --- a/net/smc/af_smc.c +++ b/net/smc/af_smc.c @@ -3344,10 +3344,7 @@ int smc_create_clcsk(struct net *net, struct sock *sk, int family) * which need net ref. */ sk = smc->clcsock->sk; - __netns_tracker_free(net, &sk->ns_tracker, false); - sk->sk_net_refcnt = 1; - get_net_track(net, &sk->ns_tracker, GFP_KERNEL); - sock_inuse_add(net, 1); + sk_net_refcnt_upgrade(sk); return 0; } diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c index 8d760f8fc4b5a63460e150114b03a2d39a4efa6d..f591ead0d96153c997adaf9dca6cf172cf5f78bf 100644 --- a/net/sunrpc/svcsock.c +++ b/net/sunrpc/svcsock.c @@ -1552,10 +1552,7 @@ static struct svc_xprt *svc_create_socket(struct svc_serv *serv, newlen = error; if (protocol == IPPROTO_TCP) { - __netns_tracker_free(net, &sock->sk->ns_tracker, false); - sock->sk->sk_net_refcnt = 1; - get_net_track(net, &sock->sk->ns_tracker, GFP_KERNEL); - sock_inuse_add(net, 1); + sk_net_refcnt_upgrade(sock->sk); if ((error = kernel_listen(sock, 64)) < 0) goto bummer; } diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c index 29df05879c8e956783469f16a2d693efdba6e33e..4e6313fac366e6c90576f9f9334eaae42fecb47b 100644 --- a/net/sunrpc/xprtsock.c +++ b/net/sunrpc/xprtsock.c @@ -1921,12 +1921,8 @@ static struct socket *xs_create_sock(struct rpc_xprt *xprt, goto out; } - if (protocol == IPPROTO_TCP) { - __netns_tracker_free(xprt->xprt_net, &sock->sk->ns_tracker, false); - sock->sk->sk_net_refcnt = 1; - get_net_track(xprt->xprt_net, &sock->sk->ns_tracker, GFP_KERNEL); - sock_inuse_add(xprt->xprt_net, 1); - } + if (protocol == IPPROTO_TCP) + sk_net_refcnt_upgrade(sock->sk); filp = sock_alloc_file(sock, O_NONBLOCK, NULL); if (IS_ERR(filp))