diff --git a/fs/ksmbd/connection.c b/fs/ksmbd/connection.c index cb9ce1dd27aeda2510fdf87414b6d8a5a597e631..bc958e43e80c42c8721226c530b5a5e4b681a9a4 100644 --- a/fs/ksmbd/connection.c +++ b/fs/ksmbd/connection.c @@ -78,6 +78,8 @@ struct ksmbd_conn *ksmbd_conn_alloc(void) spin_lock_init(&conn->llist_lock); INIT_LIST_HEAD(&conn->lock_list); + init_rwsem(&conn->session_lock); + down_write(&conn_list_lock); list_add(&conn->conns_list, &conn_list); up_write(&conn_list_lock); diff --git a/fs/ksmbd/connection.h b/fs/ksmbd/connection.h index 5f44bb27bd82bc0e078eb000e4003a008a7f48fd..cb6d67318d7c17b88361fb82b6961b721c8b2f0d 100644 --- a/fs/ksmbd/connection.h +++ b/fs/ksmbd/connection.h @@ -48,6 +48,7 @@ struct ksmbd_conn { struct ksmbd_transport *transport; struct nls_table *local_nls; struct list_head conns_list; + struct rw_semaphore session_lock; /* smb session 1 per user */ struct xarray sessions; unsigned long last_active; diff --git a/fs/ksmbd/mgmt/user_session.c b/fs/ksmbd/mgmt/user_session.c index d3f8e9d93c3b949b335486a02b4cefad26036e7e..aba451fda8fa7f6fddf0348f541c691754fd5326 100644 --- a/fs/ksmbd/mgmt/user_session.c +++ b/fs/ksmbd/mgmt/user_session.c @@ -181,18 +181,19 @@ static void ksmbd_expire_session(struct ksmbd_conn *conn) unsigned long id; struct ksmbd_session *sess; - down_write(&sessions_table_lock); + down_write(&conn->session_lock); xa_for_each(&conn->sessions, id, sess) { - if (sess->state != SMB2_SESSION_VALID || + if (atomic_read(&sess->refcnt) == 0 && + (sess->state != SMB2_SESSION_VALID || time_after(jiffies, - sess->last_active + SMB2_SESSION_TIMEOUT)) { + sess->last_active + SMB2_SESSION_TIMEOUT))) { xa_erase(&conn->sessions, sess->id); hash_del(&sess->hlist); ksmbd_session_destroy(sess); continue; } } - up_write(&sessions_table_lock); + up_write(&conn->session_lock); } int ksmbd_session_register(struct ksmbd_conn *conn, @@ -234,7 +235,9 @@ void ksmbd_sessions_deregister(struct ksmbd_conn *conn) } } } + up_write(&sessions_table_lock); + down_write(&conn->session_lock); xa_for_each(&conn->sessions, id, sess) { unsigned long chann_id; struct channel *chann; @@ -251,7 +254,7 @@ void ksmbd_sessions_deregister(struct ksmbd_conn *conn) ksmbd_session_destroy(sess); } } - up_write(&sessions_table_lock); + up_write(&conn->session_lock); } struct ksmbd_session *ksmbd_session_lookup(struct ksmbd_conn *conn, @@ -259,9 +262,11 @@ struct ksmbd_session *ksmbd_session_lookup(struct ksmbd_conn *conn, { struct ksmbd_session *sess; + down_read(&conn->session_lock); sess = xa_load(&conn->sessions, id); if (sess) sess->last_active = jiffies; + up_read(&conn->session_lock); return sess; } @@ -271,8 +276,6 @@ struct ksmbd_session *ksmbd_session_lookup_slowpath(unsigned long long id) down_read(&sessions_table_lock); sess = __session_lookup(id); - if (sess) - sess->last_active = jiffies; up_read(&sessions_table_lock); return sess; @@ -291,6 +294,22 @@ struct ksmbd_session *ksmbd_session_lookup_all(struct ksmbd_conn *conn, return sess; } +void ksmbd_user_session_get(struct ksmbd_session *sess) +{ + atomic_inc(&sess->refcnt); +} + +void ksmbd_user_session_put(struct ksmbd_session *sess) +{ + if (!sess) + return; + + if (atomic_read(&sess->refcnt) <= 0) + WARN_ON(1); + else + atomic_dec(&sess->refcnt); +} + struct preauth_session *ksmbd_preauth_session_alloc(struct ksmbd_conn *conn, u64 sess_id) { @@ -358,6 +377,7 @@ static struct ksmbd_session *__session_create(int protocol) xa_init(&sess->ksmbd_chann_list); INIT_LIST_HEAD(&sess->rpc_handle_list); sess->sequence_number = 1; + atomic_set(&sess->refcnt, 1); ret = __init_smb2_session(sess); if (ret) diff --git a/fs/ksmbd/mgmt/user_session.h b/fs/ksmbd/mgmt/user_session.h index 51f38e5b61abb48f235ccdc34aa4a8b96e1901f0..4dcbc1f235f09a6a24294cc4eed686db374df89b 100644 --- a/fs/ksmbd/mgmt/user_session.h +++ b/fs/ksmbd/mgmt/user_session.h @@ -60,6 +60,7 @@ struct ksmbd_session { struct ksmbd_file_table file_table; unsigned long last_active; + atomic_t refcnt; }; static inline int test_session_flag(struct ksmbd_session *sess, int bit) @@ -100,4 +101,6 @@ void ksmbd_release_tree_conn_id(struct ksmbd_session *sess, int id); int ksmbd_session_rpc_open(struct ksmbd_session *sess, char *rpc_name); void ksmbd_session_rpc_close(struct ksmbd_session *sess, int id); int ksmbd_session_rpc_method(struct ksmbd_session *sess, int id); +void ksmbd_user_session_get(struct ksmbd_session *sess); +void ksmbd_user_session_put(struct ksmbd_session *sess); #endif /* __USER_SESSION_MANAGEMENT_H__ */ diff --git a/fs/ksmbd/server.c b/fs/ksmbd/server.c index b353b3f91ce28310dcfff00b89e7b0bdadb168bc..daeda36a5eb1ebce4aa108b7d878e26f24039a8e 100644 --- a/fs/ksmbd/server.c +++ b/fs/ksmbd/server.c @@ -239,6 +239,8 @@ static void __handle_ksmbd_work(struct ksmbd_work *work, return; send: + if (work->sess) + ksmbd_user_session_put(work->sess); smb3_preauth_hash_rsp(work); if (work->sess && work->sess->enc && work->encrypted && conn->ops->encrypt_resp) { diff --git a/fs/ksmbd/smb2pdu.c b/fs/ksmbd/smb2pdu.c index 20f97b37e7c6b2335db30291d7eccf82e86e669e..829bb46f1f176460964ca38ba7a216ca4cdca43b 100644 --- a/fs/ksmbd/smb2pdu.c +++ b/fs/ksmbd/smb2pdu.c @@ -609,8 +609,10 @@ int smb2_check_user_session(struct ksmbd_work *work) /* Check for validity of user session */ work->sess = ksmbd_session_lookup_all(conn, sess_id); - if (work->sess) + if (work->sess) { + ksmbd_user_session_get(work->sess); return 1; + } ksmbd_debug(SMB, "Invalid user session, Uid %llu\n", sess_id); return -ENOENT; } @@ -1760,6 +1762,7 @@ int smb2_sess_setup(struct ksmbd_work *work) } conn->binding = true; + ksmbd_user_session_get(sess); } else if ((conn->dialect < SMB30_PROT_ID || server_conf.flags & KSMBD_GLOBAL_FLAG_SMB3_MULTICHANNEL) && (req->Flags & SMB2_SESSION_REQ_FLAG_BINDING)) { @@ -1786,6 +1789,7 @@ int smb2_sess_setup(struct ksmbd_work *work) } conn->binding = false; + ksmbd_user_session_get(sess); } work->sess = sess; @@ -2181,7 +2185,9 @@ int smb2_session_logoff(struct ksmbd_work *work) } ksmbd_destroy_file_table(&sess->file_table); + down_write(&conn->session_lock); sess->state = SMB2_SESSION_EXPIRED; + up_write(&conn->session_lock); ksmbd_free_user(sess->user); sess->user = NULL;