1 Star 2 Fork 0

ThE_eXpLoReR/mRNA_II

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
Config.py 4.99 KB
一键复制 编辑 原始数据 按行查看 历史
ThE_eXpLoReR 提交于 2022-04-09 15:11 . Initial commit
import time
"""
模型配置参数
"""
class Config(object):
def __init__(self, model_name, device):
"""(model:模型名称, device:使用设备)"""
"""gpu设置"""
self.device = device
"""位置编码设置"""
# 是否加入位置编码
self.pos_enc = False
# 位置编码维度设置
self.pos_enc_dim = 8
"""读取数据设置"""
# 是否为是碱基对的添加边
self.add_edge_for_paired_nodes = True
# 是否添加密码子
self.add_codon_nodes = True
# 是否使用循环类型作为输入
self.add_loop_type = True
"""模型参数设置"""
self.in_dim_node = 4 # 每个节点的类型数
self.in_dim_loop = 7 # 循环类型种类数
self.in_dim_edge = 4 # 每个边的类型数
self.n_classes = 3 # 分类个数
if self.add_codon_nodes == True: # 若使用密码子,则输入加1
self.in_dim_node = 5
self.in_dim_loop = 8
self.hidden_dim = 64 # 隐藏层维度
self.out_dim = 64 # 输出层维度
if model_name == 'GAT':
self.n_heads = 8 # 注意力头的个数
if model_name == 'GatedGCN':
self.out_dim = self.hidden_dim # 输出层维度
self.loss_type = "MCRMSE" # 使用损失函数类型
self.in_feat_dropout = 0.0 # 嵌入层dropout系数
self.dropout = 0.1 # dropout参数
self.n_layers = 6 # 卷积层数
self.batch_norm = True
self.residual = True
self.device = device
self.readout = "mean" # 图分类readout参数
"""训练相关设置"""
# 随机种子, 固定种子保证结果一致
self.seed = 123
# 初始学习率
self.init_lr = 0.001
# weight_decay值
self.weight_decay = 0.0001
# self.weight_decay = 0.0001
# epoch迭代次数
self.num_epoch = 100
# 每个batch大小
self.batch_size = 16
# 分类个数
self.num_class = 3
# 是否使用变化lr,(训练过程中lr下降)
self.lr_decay = False
self.lr_reduce_factor = 0.5 # 相关参数
self.lr_schedule_patience = 10
self.min_lr = 1e-5
"""保存模型参数设置"""
self.save_model = False
self.model_path = "./checkpoints/" + model_name + "/"
"""log文件所在位置"""
self.need_log = True
# 获取当前时间
cur_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
self.log_file = "./log/" + model_name + "/" + cur_time + ".log"
"""数据所在位置设置"""
# 训练数据所在位置
self.train_file = "./dataset/train.json"
# 验证数据所在位置
self.val_file = "./dataset/valid.json"
# 测试数据所在位置
self.test_file = "./dataset/valid.json"
if self.train_file == self.val_file:
print("train_file == val_file !!!!!!")
print("训练集和测试集使用了同一个文件————调试模式")
# if self.val_file == self.test_file:
# print("val_file == test_file !")
"""
获取模型信息
"""
def get_model_info(model_name, config):
model_info = "\n"
model_info += "-" * 42
model_info += "\n模型名称: " + model_name
model_info += "\n使用碱基对: " + str(config.add_edge_for_paired_nodes)
model_info += "\n使用密码子: " + str(config.add_codon_nodes)
model_info += "\n使用循环类型: " + str(config.add_loop_type)
model_info += "\n使用位置编码: " + str(config.pos_enc)
model_info += "\t位置编码维度: " + str(config.pos_enc_dim)
model_info += "\ninit_lr: " + str(config.init_lr) + "\tweight_decay: " + str(config.weight_decay)
model_info += "\nnum_epoch: " + str(config.num_epoch) + "\tbatch_size: " + str(config.batch_size)
model_info += "\n使用损失函数类型: " + str(config.loss_type)
model_info += "\n"
model_info += "-" * 42
model_info += "\n输入节点维度: " + str(config.in_dim_node) + \
"\t输入边维度: " + str(config.in_dim_edge) + \
"\t分类个数: " + str(config.n_classes)
model_info += "\n隐藏层维度: " + str(config.hidden_dim) + \
"\t输出层维度: " + str(config.out_dim) + \
"\t网络层数: " + str(config.n_layers)
if model_name == "GAT":
model_info += "\tn_heads: " + str(config.n_heads)
model_info += "\nin_feat_dropout: " + str(config.in_feat_dropout) + \
"\tdropout: " + str(config.dropout)
model_info += "\nbatch_norm: " + str(config.batch_norm) + \
"\tresidual: " + str(config.residual)
model_info += "\n"
model_info += "-" * 42
return model_info
if __name__ == '__main__':
model_name = 'GatedGCN'
config = Config(model_name=model_name, device='GPU')
model_info = get_model_info(model_name, config)
print(model_info)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/wang567/mRNA_II.git
git@gitee.com:wang567/mRNA_II.git
wang567
mRNA_II
mRNA_II
master

搜索帮助