diff --git a/scGen/README.md b/scGen/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..fb4845b802d61da3909b44be18a5bf11ff1b7ab9
--- /dev/null
+++ b/scGen/README.md
@@ -0,0 +1,146 @@
+# 目录
+
+- [目录](#目录)
+- [scGen描述](#scGen描述)
+- [模型架构](#模型架构)
+- [数据集](#数据集)
+- [环境要求](#环境要求)
+- [快速入门](#快速入门)
+- [脚本说明](#脚本说明)
+    - [脚本和样例代码](#脚本和样例代码)
+    - [脚本参数](#脚本参数)
+    - [训练过程](#训练过程)
+    - [推理过程](#推理过程)
+        - [用法](#用法)
+        - [结果](#结果)
+- [随机情况说明](#随机情况说明)
+- [ModelZoo主页](#modelzoo主页)
+
+# scGen描述
+
+scGen是一种生成模型,用于预测不同细胞类型、研究和物种中的单细胞扰动响应(发表于《Nature Methods》,2019年)。
+
+# 模型架构
+
+scGen 的模型架构基于变分自编码器(VAE),包括一个编码器和一个解码器。编码器将单细胞基因表达数据映射到一个潜在空间,生成细胞的隐变量表示;解码器则从潜在空间生成新的基因表达数据。scGen 通过在潜在空间中施加条件编码,实现对不同条件下细胞状态的生成与转换。
+
+# 数据集
+
+使用的数据集:[pancreas](https://drive.google.com/drive/folders/1v3qySFECxtqWLRhRTSbfQDFqdUCAXql3)
+
+- 名称: Panglao scRNA-seq数据集
+- 格式: H5AD文件
+- 路径: ./data/pancreas.h5ad
+- 数据大小: 176MB
+
+支持的数据集:[pancreas] 或者与 AnnData 格式相同的数据集
+
+- 目录结构如下,由用户定义目录和文件的名称
+
+![image](demo/AnnData.png)
+
+- 如果用户需要自定义数据集,则需要将数据集格式转化为AnnData数据格式。
+
+# 环境要求
+
+- 硬件(Ascend)
+    - 使用Ascend处理器来搭建硬件环境。
+- 框架
+    - [MindSpore](https://www.mindspore.cn/install)
+- 如需查看详情,请参见如下资源
+    - [MindSpore教程](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.html)
+
+# 快速入门
+
+- 通过官方网站安装MindSpore后,您可以按照如下步骤进行训练
+
+```shell
+# 单卡训练
+python train_model.py
+```
+
+# 脚本说明
+
+## 脚本和样例代码
+
+```text
+ |----data
+ |----demo
+    |----AnnData.png
+ |----models
+    |----vae.py
+ |----utils
+    |----data_utils.py
+    |----model_utils.py
+ |----batch_correction.py
+ |----batch_effect_removal.py
+ |----README.md
+ |----train.py
+ |----train_model.py
+```
+
+## 脚本参数
+
+训练参数(train_model.py):
+```text
+--train_path 数据路径
+--model_to_use 模型路径
+--batch_size 批次大小, 默认: 32
+--X_dim 基因表达矩阵的特征维度
+--z_dim 潜在空间维度,默认值:100
+--lr 学习率,默认值:0.001
+--dr_rate dropout,默认值:0.2
+```
+
+## 训练过程
+
+在Ascend设备上,使用python脚本直接开始训练
+
+
python命令启动 + + ```shell + # 单卡训练 + python train_model.py + ``` + +```text + + Epoch [1/100], Loss: 981.0034 + Epoch [2/100], Loss: 939.3733 + Epoch [3/100], Loss: 922.7879 + Epoch [4/100], Loss: 913.1795 + Epoch [5/100], Loss: 905.8361 + Epoch [6/100], Loss: 900.6747 + +``` + +## 推理过程 + +**推理前需使用train_model.py文件生成的模型检查点文件。** + +### 用法 + +执行完整的推理脚本如下: + +```shell + +python batch_effect_removal.py + +``` + +### 结果 + +去批次结果保存在/batch_removal_data/1.h5ad中。 + +# 随机情况说明 + +在训练中存在以下随机性来源: + +1. 数据和索引的随机打乱 +2. 潜在空间中的随机噪声生成 +3. 样本生成时的噪声引入 + +# ModelZoo主页 + +请浏览官网[主页](https://gitee.com/mindspore/models)。 \ No newline at end of file diff --git a/scGen/batch_correction.py b/scGen/batch_correction.py new file mode 100644 index 0000000000000000000000000000000000000000..4e07d56e75342fccae9f911d28d8b38d06eefe36 --- /dev/null +++ b/scGen/batch_correction.py @@ -0,0 +1,65 @@ +import numpy as np +import scanpy as sc +import anndata +from utils.model_utils import give_me_latent, reconstruct + +def vector_batch_removal(model, data): + latent_all = give_me_latent(model, data.X) + latent_ann = sc.AnnData(latent_all) + latent_ann.obs["cell_type"] = data.obs["cell_type"].tolist() + latent_ann.obs["batch"] = data.obs["batch"].tolist() + latent_ann.obs["sample"] = data.obs["sample"].tolist() + + unique_cell_types = np.unique(latent_ann.obs["cell_type"]) + shared_anns = [] + not_shared_ann = [] + + for cell_type in unique_cell_types: + temp_cell = latent_ann[latent_ann.obs["cell_type"] == cell_type] + if len(np.unique(temp_cell.obs["batch"])) < 2: + cell_type_ann = latent_ann[latent_ann.obs["cell_type"] == cell_type] + not_shared_ann.append(cell_type_ann) + continue + + print(cell_type) + temp_cell = latent_ann[latent_ann.obs["cell_type"] == cell_type] + batch_list = {} + max_batch = 0 + max_batch_ind = "" + batchs = np.unique(temp_cell.obs["batch"]) + + for i in batchs: + temp = temp_cell[temp_cell.obs["batch"] == i] + if max_batch < len(temp): + max_batch = len(temp) + max_batch_ind = i + 
batch_list[i] = temp + + max_batch_ann = batch_list[max_batch_ind] + for study in batch_list: + delta = np.average(max_batch_ann.X, axis=0) - np.average(batch_list[study].X, axis=0) + batch_list[study] = batch_list[study].copy() + batch_list[study].X = delta + batch_list[study].X + corrected = anndata.concat(list(batch_list.values())) + shared_anns.append(corrected) + + all_shared_ann = anndata.concat(shared_anns) if shared_anns else sc.AnnData() + all_not_shared_ann = anndata.concat(not_shared_ann) if not_shared_ann else sc.AnnData() + all_corrected_data = anndata.concat([all_shared_ann, all_not_shared_ann]) + + corrected_data = reconstruct(model, all_corrected_data.X, use_data=True) + corrected = sc.AnnData(corrected_data) + corrected.obs["cell_type"] = all_corrected_data.obs["cell_type"].tolist() + corrected.obs["study"] = all_corrected_data.obs["sample"].tolist() + corrected.var_names = data.var_names.tolist() + + if all_shared_ann.n_obs > 0: + corrected_shared_data = reconstruct(model, all_shared_ann.X, use_data=True) + corrected_shared = sc.AnnData(corrected_shared_data) + corrected_shared.obs["cell_type"] = all_shared_ann.obs["cell_type"].tolist() + corrected_shared.obs["study"] = all_shared_ann.obs["sample"].tolist() + corrected_shared.var_names = data.var_names.tolist() + else: + corrected_shared = sc.AnnData() + + return corrected, corrected_shared \ No newline at end of file diff --git a/scGen/batch_effect_removal.py b/scGen/batch_effect_removal.py new file mode 100644 index 0000000000000000000000000000000000000000..ab4a6b6eb7ca96d41667ef6bc430331fa73a68ff --- /dev/null +++ b/scGen/batch_effect_removal.py @@ -0,0 +1,43 @@ +import mindspore as ms +from models.vae import VAE +from utils.data_utils import load_data +from batch_correction import vector_batch_removal +import anndata as ad + +# 设置MindSpore上下文 +ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend") + +# 配置参数 +train_path = "./data/pancreas.h5ad" +model_path = "./models/scGen/scgen.pt" 
+output_path = "./batch_removal_data/1.h5ad" + +def main(): + # 加载数据 + data, _, _ = load_data(train_path) + gex_size = data.X.shape[1] + + # 初始化并加载已训练模型 + model = VAE(input_dim=gex_size, z_dim=100, dr_rate=0.2) + param_dict = ms.load_checkpoint(model_path) + ms.load_param_into_net(model, param_dict) + print("模型加载完毕。") + + # 批次效应去除 + all_data, shared = vector_batch_removal(model, data) + + # 后处理 + top_cell_types = all_data.obs["cell_type"].value_counts().index.tolist()[:7] + if "not applicable" in top_cell_types: + top_cell_types.remove("not applicable") + + all_data.obs["celltype"] = "others" + for cell_type in top_cell_types: + all_data.obs.loc[all_data.obs["cell_type"] == cell_type, "celltype"] = cell_type + + # 保存结果 + all_data.write(output_path) + print(f"scGen batch corrected pancreas has been saved in {output_path}") + +if __name__ == "__main__": + main() diff --git a/scGen/demo/AnnData.png b/scGen/demo/AnnData.png new file mode 100644 index 0000000000000000000000000000000000000000..e6aa27577a0d28323dcadf6deffab18faf0dbf1c Binary files /dev/null and b/scGen/demo/AnnData.png differ diff --git a/scGen/models/vae.py b/scGen/models/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..5c41fe832a89048f97e47f2a36161383cf3985ae --- /dev/null +++ b/scGen/models/vae.py @@ -0,0 +1,63 @@ +import numpy as np +import mindspore as ms +import mindspore.ops as ops +import mindspore.nn as nn +from mindspore import Tensor, dtype as mstype + +class VAE(nn.Cell): + def __init__(self, input_dim, hidden_dim=800, z_dim=100, dr_rate=0.2): + super(VAE, self).__init__() + + # =============================== Q(z|X) ====================================== + self.encoder = nn.SequentialCell([ + nn.Dense(input_dim, hidden_dim, has_bias=False), + nn.BatchNorm1d(hidden_dim), + nn.LeakyReLU(alpha=0.01), + nn.Dropout(p=dr_rate), + nn.Dense(hidden_dim, hidden_dim, has_bias=False), + nn.BatchNorm1d(hidden_dim), + nn.LeakyReLU(), + nn.Dropout(p=dr_rate), + ]) + self.fc_mean = 
nn.Dense(hidden_dim, z_dim) + self.fc_var = nn.Dense(hidden_dim, z_dim) + + # =============================== P(X|z) ====================================== + self.decoder = nn.SequentialCell([ + nn.Dense(z_dim, hidden_dim, has_bias=False), + nn.BatchNorm1d(hidden_dim), + nn.LeakyReLU(alpha=0.01), + nn.Dropout(p=dr_rate), + nn.Dense(hidden_dim, hidden_dim, has_bias=False), + nn.BatchNorm1d(hidden_dim), + nn.LeakyReLU(), + nn.Dropout(p=dr_rate), + nn.Dense(hidden_dim, input_dim), + nn.ReLU(), + ]) + + self.exp = ops.Exp() + self.randn_like = ops.StandardNormal() + + def encode(self, x): + h = self.encoder(x) + mean = self.fc_mean(h) + log_var = self.fc_var(h) + return mean, log_var + + def reparameterize(self, mu, log_var): + std = self.exp(0.5 * log_var) + shape_tuple = ops.Shape()(std) + shape = Tensor(list(shape_tuple), mstype.int32) + eps = self.randn_like(shape) + return mu + eps * std + + def decode(self, z): + x_hat = self.decoder(z) + return x_hat + + def construct(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_hat = self.decode(z) + return x_hat, mu, log_var \ No newline at end of file diff --git a/scGen/train.py b/scGen/train.py new file mode 100644 index 0000000000000000000000000000000000000000..21031c382707e27bfa24f4818d1c6f136311f2b7 --- /dev/null +++ b/scGen/train.py @@ -0,0 +1,37 @@ +import numpy as np +import mindspore as ms +from mindspore import ops, Tensor +from mindspore.train.serialization import save_checkpoint + +def train(model, optimizer, data, n_epochs, batch_size=32, model_path=None): + model.set_train() + data_size = data.shape[0] + + for epoch in range(n_epochs): + permutation = np.random.permutation(data_size) + data = data[permutation, :] + train_loss = 0 + + for i in range(0, data_size, batch_size): + batch_data_np = data[i:i + batch_size] + batch_data = Tensor(batch_data_np, ms.float32) + + def forward_fn(batch_data): + x_hat, mu, log_var = model(batch_data) + recon_loss = 0.5 * 
ops.reduce_sum(ops.mse_loss(x_hat, batch_data)) + kl_loss = 0.5 * ops.reduce_sum(ops.exp(log_var) + ops.square(mu) - 1 - log_var) + vae_loss = recon_loss + 0.00005 * kl_loss + return vae_loss + + grads = ops.GradOperation(get_by_list=True)(forward_fn, model.trainable_params())(batch_data) + optimizer(grads) + + vae_loss = forward_fn(batch_data) + train_loss += vae_loss.asnumpy() + + avg_loss = train_loss / data_size + print(f"Epoch [{epoch + 1}/{n_epochs}], Loss: {avg_loss:.4f}") + + if model_path: + save_checkpoint(model.trainable_params(), model_path) + print(f"模型已保存到 {model_path}") \ No newline at end of file diff --git a/scGen/train_model.py b/scGen/train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d3ae43924d44104af3d602fca4ae2cbb1d23e50f --- /dev/null +++ b/scGen/train_model.py @@ -0,0 +1,36 @@ +import mindspore as ms +from models.vae import VAE +from utils.data_utils import load_data +from train import train +from batch_correction import vector_batch_removal + +# 设置MindSpore上下文 +ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend") + +# 配置参数 +train_path = "./data/pancreas.h5ad" +model_path = "./models/scGen/scgen.pt" +batch_size = 32 +z_dim = 100 +lr = 0.001 +dr_rate = 0.2 + +def main(): + # 加载数据 + data, train_data, input_matrix = load_data(train_path) + gex_size = input_matrix.shape[1] + + # 初始化模型 + model = VAE(input_dim=gex_size, z_dim=z_dim, dr_rate=dr_rate) + optimizer = ms.experimental.optim.Adam(model.trainable_params(), lr=lr) + + # 数据预处理 + data.obs["study"] = data.obs["sample"] + data.obs["cell_type"] = data.obs["celltype"] + + # 训练模型 + train(model, optimizer, train_data, n_epochs=100, model_path=model_path) + print(f"模型已保存到 {model_path}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scGen/utils/data_utils.py b/scGen/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..284a0f81b1f8678efbe795895938709d87048c01 --- /dev/null +++ 
b/scGen/utils/data_utils.py @@ -0,0 +1,12 @@ +import anndata +import numpy as np +import scanpy as sc +from random import shuffle + +def load_data(train_path): + data = sc.read(train_path) + input_matrix = data.X + ind_list = [i for i in range(input_matrix.shape[0])] + shuffle(ind_list) + train_data = input_matrix[ind_list, :] + return data, train_data, input_matrix \ No newline at end of file diff --git a/scGen/utils/model_utils.py b/scGen/utils/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cf7ce24d32d7020bbcea5b77eb6bce554bdfd479 --- /dev/null +++ b/scGen/utils/model_utils.py @@ -0,0 +1,31 @@ +import numpy as np +from mindspore import Tensor +import mindspore as ms + +def give_me_latent(model, data): + model.set_train(False) + data_tensor = Tensor(data, dtype=ms.float32) + mu = model.encode(data_tensor) + return mu.asnumpy() + +def avg_vector(model, data): + latent = give_me_latent(model, data) + arithmatic = np.average(latent, axis=0) + return arithmatic + +def reconstruct(model, data, use_data=False): + model.set_train(False) + if use_data: + latent_tensor = Tensor(data, dtype=ms.float32) + else: + latent_np = give_me_latent(model, data) + latent_tensor = Tensor(latent_np, dtype=ms.float32) + reconstructed_tensor = model.decode(latent_tensor) + return reconstructed_tensor.asnumpy() + +def sample(model, n_sample, z_dim): + model.set_train(False) + noise_np = np.random.randn(n_sample, z_dim).astype(np.float32) + noise_tensor = Tensor(noise_np) + gen_cells_tensor = model.decode(noise_tensor) + return gen_cells_tensor.asnumpy() \ No newline at end of file