diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h index 3ecfeadbfa0694ad129bef20466de6b6c8b5adba..bdc4c38baa9f0695e668d2f3bcca0789d16a08f5 100644 --- a/include/net/inet_hashtables.h +++ b/include/net/inet_hashtables.h @@ -89,6 +89,7 @@ struct inet_bind_bucket { bool fast_ipv6_only; struct hlist_node node; struct hlist_head owners; + struct rcu_head rcu; }; struct inet_bind2_bucket { @@ -231,8 +232,7 @@ struct inet_bind_bucket * inet_bind_bucket_create(struct kmem_cache *cachep, struct net *net, struct inet_bind_hashbucket *head, const unsigned short snum, int l3mdev); -void inet_bind_bucket_destroy(struct kmem_cache *cachep, - struct inet_bind_bucket *tb); +void inet_bind_bucket_destroy(struct inet_bind_bucket *tb); bool inet_bind_bucket_match(const struct inet_bind_bucket *tb, const struct net *net, unsigned short port, @@ -532,9 +532,12 @@ static inline void sk_rcv_saddr_set(struct sock *sk, __be32 addr) int __inet_hash_connect(struct inet_timewait_death_row *death_row, struct sock *sk, u64 port_offset, + u32 hash_port0, int (*check_established)(struct inet_timewait_death_row *, struct sock *, __u16, - struct inet_timewait_sock **)); + struct inet_timewait_sock **, + bool rcu_lookup, + u32 hash)); int inet_hash_connect(struct inet_timewait_death_row *death_row, struct sock *sk); diff --git a/include/net/ip.h b/include/net/ip.h index ae68d3f157a8e5446fd0bb360075fafabd2ea481..a21f2af40be0460ede7cf2aaece23fa8e98790b5 100644 --- a/include/net/ip.h +++ b/include/net/ip.h @@ -353,7 +353,7 @@ void inet_get_local_port_range(const struct net *net, int *low, int *high); void inet_sk_get_local_port_range(const struct sock *sk, int *low, int *high); #ifdef CONFIG_SYSCTL -static inline bool inet_is_local_reserved_port(struct net *net, unsigned short port) +static inline bool inet_is_local_reserved_port(const struct net *net, unsigned short port) { if (!net->ipv4.sysctl_local_reserved_ports) return false; diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c index bd032ac2376ed7d705012bee86ca45ed42a14001..177dac731fcd67357f521d80174460c91d0f70f9 100644 --- a/net/ipv4/inet_connection_sock.c +++ b/net/ipv4/inet_connection_sock.c @@ -155,10 +155,11 @@ static bool inet_use_bhash2_on_bind(const struct sock *sk) { #if IS_ENABLED(CONFIG_IPV6) if (sk->sk_family == AF_INET6) { - int addr_type = ipv6_addr_type(&sk->sk_v6_rcv_saddr); + if (ipv6_addr_any(&sk->sk_v6_rcv_saddr)) + return false; - return addr_type != IPV6_ADDR_ANY && - addr_type != IPV6_ADDR_MAPPED; + if (!ipv6_addr_v4mapped(&sk->sk_v6_rcv_saddr)) + return true; } #endif return sk->sk_rcv_saddr != htonl(INADDR_ANY); @@ -198,8 +199,15 @@ static bool __inet_bhash2_conflict(const struct sock *sk, struct sock *sk2, kuid_t sk_uid, bool relax, bool reuseport_cb_ok, bool reuseport_ok) { - if (sk->sk_family == AF_INET && ipv6_only_sock(sk2)) - return false; + if (ipv6_only_sock(sk2)) { + if (sk->sk_family == AF_INET) + return false; + +#if IS_ENABLED(CONFIG_IPV6) + if (ipv6_addr_v4mapped(&sk->sk_v6_rcv_saddr)) + return false; +#endif + } return inet_bind_conflict(sk, sk2, sk_uid, relax, reuseport_cb_ok, reuseport_ok); @@ -593,7 +601,7 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum) fail_unlock: if (ret) { if (bhash_created) - inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb); + inet_bind_bucket_destroy(tb); if (bhash2_created) inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep, tb2); diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c index 9ecabf87b4740bcf45864c3249f0f5a473c3fe0f..5374ef7a7bf630ca2ee9c7eb5b0e2b2dfa73d538 100644 --- a/net/ipv4/inet_hashtables.c +++ b/net/ipv4/inet_hashtables.c @@ -38,8 +38,8 @@ u32 inet_ehashfn(const struct net *net, const __be32 laddr, net_get_random_once(&inet_ehash_secret, sizeof(inet_ehash_secret)); - return __inet_ehashfn(laddr, lport, faddr, fport, - inet_ehash_secret + net_hash_mix(net)); + return lport + __inet_ehashfn(laddr, 0, faddr, fport, + inet_ehash_secret + net_hash_mix(net)); } EXPORT_SYMBOL_GPL(inet_ehashfn); @@ -79,7 +79,7 @@ struct inet_bind_bucket *inet_bind_bucket_create(struct kmem_cache *cachep, tb->fastreuse = 0; tb->fastreuseport = 0; INIT_HLIST_HEAD(&tb->owners); - hlist_add_head(&tb->node, &head->chain); + hlist_add_head_rcu(&tb->node, &head->chain); } return tb; } @@ -87,11 +87,11 @@ struct inet_bind_bucket *inet_bind_bucket_create(struct kmem_cache *cachep, /* * Caller must hold hashbucket lock for this tb with local BH disabled */ -void inet_bind_bucket_destroy(struct kmem_cache *cachep, struct inet_bind_bucket *tb) +void inet_bind_bucket_destroy(struct inet_bind_bucket *tb) { if (hlist_empty(&tb->owners)) { - __hlist_del(&tb->node); - kmem_cache_free(cachep, tb); + hlist_del_rcu(&tb->node); + kfree_rcu(tb, rcu); } } @@ -197,7 +197,7 @@ static void __inet_put_port(struct sock *sk) __sk_del_bind_node(sk); inet_csk(sk)->icsk_bind_hash = NULL; inet_sk(sk)->inet_num = 0; - inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb); + inet_bind_bucket_destroy(tb); spin_lock(&head2->lock); if (inet_csk(sk)->icsk_bind2_hash) { @@ -293,7 +293,7 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child) error: if (created_inet_bind_bucket) - inet_bind_bucket_destroy(table->bind_bucket_cachep, tb); + inet_bind_bucket_destroy(tb); spin_unlock(&head2->lock); spin_unlock(&head->lock); return -ENOMEM; @@ -545,7 +545,9 @@ EXPORT_SYMBOL_GPL(__inet_lookup_established); /* called with local bh disabled */ static int __inet_check_established(struct inet_timewait_death_row *death_row, struct sock *sk, __u16 lport, - struct inet_timewait_sock **twp) + struct inet_timewait_sock **twp, + bool rcu_lookup, + u32 hash) { struct inet_hashinfo *hinfo = death_row->hashinfo; struct inet_sock *inet = inet_sk(sk); @@ -556,14 +558,25 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row, int sdif = l3mdev_master_ifindex_by_index(net, dif); INET_ADDR_COOKIE(acookie, saddr, daddr); const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport); - unsigned int hash = inet_ehashfn(net, daddr, lport, - saddr, inet->inet_dport); struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash); - spinlock_t *lock = inet_ehash_lockp(hinfo, hash); - struct sock *sk2; - const struct hlist_nulls_node *node; struct inet_timewait_sock *tw = NULL; + const struct hlist_nulls_node *node; + struct sock *sk2; + spinlock_t *lock; + if (rcu_lookup) { + sk_nulls_for_each(sk2, node, &head->chain) { + if (sk2->sk_hash != hash || + !inet_match(net, sk2, acookie, ports, dif, sdif)) + continue; + if (sk2->sk_state == TCP_TIME_WAIT) + break; + return -EADDRNOTAVAIL; + } + return 0; + } + + lock = inet_ehash_lockp(hinfo, hash); spin_lock(lock); sk_nulls_for_each(sk2, node, &head->chain) { @@ -843,7 +856,8 @@ bool inet_bind2_bucket_match_addr_any(const struct inet_bind2_bucket *tb, const return ipv6_addr_any(&tb->v6_rcv_saddr) || ipv6_addr_v4mapped_any(&tb->v6_rcv_saddr); - return false; + return ipv6_addr_v4mapped(&sk->sk_v6_rcv_saddr) && + tb->rcv_saddr == 0; } if (sk->sk_family == AF_INET6) @@ -1001,8 +1015,10 @@ static u32 *table_perturb; int __inet_hash_connect(struct inet_timewait_death_row *death_row, struct sock *sk, u64 port_offset, + u32 hash_port0, int (*check_established)(struct inet_timewait_death_row *, - struct sock *, __u16, struct inet_timewait_sock **)) + struct sock *, __u16, struct inet_timewait_sock **, + bool rcu_lookup, u32 hash)) { struct inet_hashinfo *hinfo = death_row->hashinfo; struct inet_bind_hashbucket *head, *head2; @@ -1019,7 +1035,8 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row, if (port) { local_bh_disable(); - ret = check_established(death_row, sk, port, NULL); + ret = check_established(death_row, sk, port, NULL, false, + hash_port0 + port); local_bh_enable(); return ret; } @@ -1057,6 +1074,22 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row, continue; head = &hinfo->bhash[inet_bhashfn(net, port, hinfo->bhash_size)]; + rcu_read_lock(); + hlist_for_each_entry_rcu(tb, &head->chain, node) { + if (!inet_bind_bucket_match(tb, net, port, l3mdev)) + continue; + if (tb->fastreuse >= 0 || tb->fastreuseport >= 0) { + rcu_read_unlock(); + goto next_port; + } + if (!check_established(death_row, sk, port, &tw, true, + hash_port0 + port)) + break; + rcu_read_unlock(); + goto next_port; + } + rcu_read_unlock(); + spin_lock_bh(&head->lock); /* Does not bother with rcv_saddr checks, because @@ -1066,12 +1099,13 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row, if (inet_bind_bucket_match(tb, net, port, l3mdev)) { if (tb->fastreuse >= 0 || tb->fastreuseport >= 0) - goto next_port; + goto next_port_unlock; WARN_ON(hlist_empty(&tb->owners)); if (!check_established(death_row, sk, - port, &tw)) + port, &tw, false, + hash_port0 + port)) goto ok; - goto next_port; + goto next_port_unlock; } } @@ -1085,8 +1119,9 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row, tb->fastreuse = -1; tb->fastreuseport = -1; goto ok; -next_port: +next_port_unlock: spin_unlock_bh(&head->lock); +next_port: cond_resched(); } @@ -1157,7 +1192,7 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row, spin_unlock(&head2->lock); if (tb_created) - inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb); + inet_bind_bucket_destroy(tb); spin_unlock(&head->lock); if (tw) @@ -1174,11 +1209,18 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row, int inet_hash_connect(struct inet_timewait_death_row *death_row, struct sock *sk) { + const struct inet_sock *inet = inet_sk(sk); + const struct net *net = sock_net(sk); u64 port_offset = 0; + u32 hash_port0; if (!inet_sk(sk)->inet_num) port_offset = inet_sk_port_offset(sk); - return __inet_hash_connect(death_row, sk, port_offset, + + hash_port0 = inet_ehashfn(net, inet->inet_rcv_saddr, 0, + inet->inet_daddr, inet->inet_dport); + + return __inet_hash_connect(death_row, sk, port_offset, hash_port0, __inet_check_established); } EXPORT_SYMBOL_GPL(inet_hash_connect); diff --git a/net/ipv4/inet_timewait_sock.c b/net/ipv4/inet_timewait_sock.c index fff53144250c5196c52c475dd8558a8cebd7eee3..f26745fa3d632c23efa144cac6a060239bc06e86 100644 --- a/net/ipv4/inet_timewait_sock.c +++ b/net/ipv4/inet_timewait_sock.c @@ -37,7 +37,7 @@ void inet_twsk_bind_unhash(struct inet_timewait_sock *tw, __hlist_del(&tw->tw_bind_node); tw->tw_tb = NULL; - inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb); + inet_bind_bucket_destroy(tb); __hlist_del(&tw->tw_bind2_node); tw->tw_tb2 = NULL; diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c index b0e8d278e8a9b794d0001efdc0f43716f9a34f8f..6a838ce959ed5b97ebf2d648f5d60f0f3b58d72d 100644 --- a/net/ipv6/inet6_hashtables.c +++ b/net/ipv6/inet6_hashtables.c @@ -36,8 +36,8 @@ u32 inet6_ehashfn(const struct net *net, lhash = (__force u32)laddr->s6_addr32[3]; fhash = __ipv6_addr_jhash(faddr, ipv6_hash_secret); - return __inet6_ehashfn(lhash, lport, fhash, fport, - inet6_ehash_secret + net_hash_mix(net)); + return lport + __inet6_ehashfn(lhash, 0, fhash, fport, + inet6_ehash_secret + net_hash_mix(net)); } EXPORT_SYMBOL_GPL(inet6_ehashfn); @@ -263,7 +263,9 @@ EXPORT_SYMBOL_GPL(inet6_lookup); static int __inet6_check_established(struct inet_timewait_death_row *death_row, struct sock *sk, const __u16 lport, - struct inet_timewait_sock **twp) + struct inet_timewait_sock **twp, + bool rcu_lookup, + u32 hash) { struct inet_hashinfo *hinfo = death_row->hashinfo; struct inet_sock *inet = inet_sk(sk); @@ -273,14 +275,26 @@ static int __inet6_check_established(struct inet_timewait_death_row *death_row, struct net *net = sock_net(sk); const int sdif = l3mdev_master_ifindex_by_index(net, dif); const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport); - const unsigned int hash = inet6_ehashfn(net, daddr, lport, saddr, - inet->inet_dport); struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash); - spinlock_t *lock = inet_ehash_lockp(hinfo, hash); - struct sock *sk2; - const struct hlist_nulls_node *node; struct inet_timewait_sock *tw = NULL; + const struct hlist_nulls_node *node; + struct sock *sk2; + spinlock_t *lock; + + if (rcu_lookup) { + sk_nulls_for_each(sk2, node, &head->chain) { + if (sk2->sk_hash != hash || + !inet6_match(net, sk2, saddr, daddr, + ports, dif, sdif)) + continue; + if (sk2->sk_state == TCP_TIME_WAIT) + break; + return -EADDRNOTAVAIL; + } + return 0; + } + lock = inet_ehash_lockp(hinfo, hash); spin_lock(lock); sk_nulls_for_each(sk2, node, &head->chain) { @@ -338,11 +352,19 @@ static u64 inet6_sk_port_offset(const struct sock *sk) int inet6_hash_connect(struct inet_timewait_death_row *death_row, struct sock *sk) { + const struct in6_addr *daddr = &sk->sk_v6_rcv_saddr; + const struct in6_addr *saddr = &sk->sk_v6_daddr; + const struct inet_sock *inet = inet_sk(sk); + const struct net *net = sock_net(sk); u64 port_offset = 0; + u32 hash_port0; if (!inet_sk(sk)->inet_num) port_offset = inet6_sk_port_offset(sk); - return __inet_hash_connect(death_row, sk, port_offset, + + hash_port0 = inet6_ehashfn(net, daddr, 0, saddr, inet->inet_dport); + + return __inet_hash_connect(death_row, sk, port_offset, hash_port0, __inet6_check_established); } EXPORT_SYMBOL_GPL(inet6_hash_connect);