diff --git a/net/mptcp/pm.c b/net/mptcp/pm.c index 53e0b08b11232e21e258515326dd63149f60d04b..8847fd02f3d98640e7bc40d85fb4e8a55f4f96e9 100644 --- a/net/mptcp/pm.c +++ b/net/mptcp/pm.c @@ -237,7 +237,9 @@ void mptcp_pm_add_addr_received(const struct sock *ssk, } else { __MPTCP_INC_STATS(sock_net((struct sock *)msk), MPTCP_MIB_ADDADDRDROP); } - } else if (!READ_ONCE(pm->accept_addr)) { + /* id0 should not have a different address */ + } else if ((addr->id == 0 && !mptcp_pm_nl_is_init_remote_addr(msk, addr)) || + (addr->id > 0 && !READ_ONCE(pm->accept_addr))) { mptcp_pm_announce_addr(msk, addr, true); mptcp_pm_add_addr_send_ack(msk); } else if (mptcp_pm_schedule_work(msk, MPTCP_PM_ADD_ADDR_RECEIVED)) { diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c index aa56696cdc78c2a1ed95897362e7733eb88a25b5..e0f6f72bf7d568fe5115a3dc023b5ff1f80ee8ae 100644 --- a/net/mptcp/pm_netlink.c +++ b/net/mptcp/pm_netlink.c @@ -743,6 +743,15 @@ static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk) } } +bool mptcp_pm_nl_is_init_remote_addr(struct mptcp_sock *msk, + const struct mptcp_addr_info *remote) +{ + struct mptcp_addr_info mpc_remote; + + remote_address((struct sock_common *)msk, &mpc_remote); + return mptcp_addresses_equal(&mpc_remote, remote, remote->port); +} + void mptcp_pm_nl_addr_send_ack(struct mptcp_sock *msk) { struct mptcp_subflow_context *subflow; @@ -841,7 +850,7 @@ static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk, mptcp_close_ssk(sk, ssk, subflow); spin_lock_bh(&msk->pm.lock); - removed = true; + removed |= subflow->request_join; if (rm_type == MPTCP_MIB_RMSUBFLOW) __MPTCP_INC_STATS(sock_net(sk), rm_type); } @@ -855,9 +864,13 @@ static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk, if (!mptcp_pm_is_kernel(msk)) continue; - if (rm_type == MPTCP_MIB_RMADDR) { - msk->pm.add_addr_accepted--; - WRITE_ONCE(msk->pm.accept_addr, true); + if (rm_type == MPTCP_MIB_RMADDR && rm_id && + !WARN_ON_ONCE(msk->pm.add_addr_accepted == 0)) { + /* Note: if the subflow has been closed before, this + * add_addr_accepted counter will not be decremented. + */ + if (--msk->pm.add_addr_accepted < mptcp_pm_get_add_addr_accept_max(msk)) + WRITE_ONCE(msk->pm.accept_addr, true); } else if (rm_type == MPTCP_MIB_RMSUBFLOW) { msk->pm.local_addr_used--; } diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h index 524516025b6e6689eb8aada89fc532ce4203e15d..aaeff21553d10d142d2b27c2365b683b9b032d57 100644 --- a/net/mptcp/protocol.h +++ b/net/mptcp/protocol.h @@ -907,6 +907,8 @@ void mptcp_pm_add_addr_received(const struct sock *ssk, void mptcp_pm_add_addr_echoed(struct mptcp_sock *msk, const struct mptcp_addr_info *addr); void mptcp_pm_add_addr_send_ack(struct mptcp_sock *msk); +bool mptcp_pm_nl_is_init_remote_addr(struct mptcp_sock *msk, + const struct mptcp_addr_info *remote); void mptcp_pm_nl_addr_send_ack(struct mptcp_sock *msk); void mptcp_pm_rm_addr_received(struct mptcp_sock *msk, const struct mptcp_rm_list *rm_list);