代码拉取完成,页面将自动刷新
import time
"""
模型配置参数
"""
class Config(object):
def __init__(self, model_name, device):
"""(model:模型名称, device:使用设备)"""
"""gpu设置"""
self.device = device
"""位置编码设置"""
# 是否加入位置编码
self.pos_enc = False
# 位置编码维度设置
self.pos_enc_dim = 8
"""读取数据设置"""
# 是否为是碱基对的添加边
self.add_edge_for_paired_nodes = True
# 是否添加密码子
self.add_codon_nodes = True
# 是否使用循环类型作为输入
self.add_loop_type = True
"""模型参数设置"""
self.in_dim_node = 4 # 每个节点的类型数
self.in_dim_loop = 7 # 循环类型种类数
self.in_dim_edge = 4 # 每个边的类型数
self.n_classes = 3 # 分类个数
if self.add_codon_nodes == True: # 若使用密码子,则输入加1
self.in_dim_node = 5
self.in_dim_loop = 8
self.hidden_dim = 64 # 隐藏层维度
self.out_dim = 64 # 输出层维度
if model_name == 'GAT':
self.n_heads = 8 # 注意力头的个数
if model_name == 'GatedGCN':
self.out_dim = self.hidden_dim # 输出层维度
self.loss_type = "MCRMSE" # 使用损失函数类型
self.in_feat_dropout = 0.0 # 嵌入层dropout系数
self.dropout = 0.1 # dropout参数
self.n_layers = 6 # 卷积层数
self.batch_norm = True
self.residual = True
self.device = device
self.readout = "mean" # 图分类readout参数
"""训练相关设置"""
# 随机种子, 固定种子保证结果一致
self.seed = 123
# 初始学习率
self.init_lr = 0.001
# weight_decay值
self.weight_decay = 0.0001
# self.weight_decay = 0.0001
# epoch迭代次数
self.num_epoch = 100
# 每个batch大小
self.batch_size = 16
# 分类个数
self.num_class = 3
# 是否使用变化lr,(训练过程中lr下降)
self.lr_decay = False
self.lr_reduce_factor = 0.5 # 相关参数
self.lr_schedule_patience = 10
self.min_lr = 1e-5
"""保存模型参数设置"""
self.save_model = False
self.model_path = "./checkpoints/" + model_name + "/"
"""log文件所在位置"""
self.need_log = True
# 获取当前时间
cur_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
self.log_file = "./log/" + model_name + "/" + cur_time + ".log"
"""数据所在位置设置"""
# 训练数据所在位置
self.train_file = "./dataset/train.json"
# 验证数据所在位置
self.val_file = "./dataset/valid.json"
# 测试数据所在位置
self.test_file = "./dataset/valid.json"
if self.train_file == self.val_file:
print("train_file == val_file !!!!!!")
print("训练集和测试集使用了同一个文件————调试模式")
# if self.val_file == self.test_file:
# print("val_file == test_file !")
"""
获取模型信息
"""
def get_model_info(model_name, config):
model_info = "\n"
model_info += "-" * 42
model_info += "\n模型名称: " + model_name
model_info += "\n使用碱基对: " + str(config.add_edge_for_paired_nodes)
model_info += "\n使用密码子: " + str(config.add_codon_nodes)
model_info += "\n使用循环类型: " + str(config.add_loop_type)
model_info += "\n使用位置编码: " + str(config.pos_enc)
model_info += "\t位置编码维度: " + str(config.pos_enc_dim)
model_info += "\ninit_lr: " + str(config.init_lr) + "\tweight_decay: " + str(config.weight_decay)
model_info += "\nnum_epoch: " + str(config.num_epoch) + "\tbatch_size: " + str(config.batch_size)
model_info += "\n使用损失函数类型: " + str(config.loss_type)
model_info += "\n"
model_info += "-" * 42
model_info += "\n输入节点维度: " + str(config.in_dim_node) + \
"\t输入边维度: " + str(config.in_dim_edge) + \
"\t分类个数: " + str(config.n_classes)
model_info += "\n隐藏层维度: " + str(config.hidden_dim) + \
"\t输出层维度: " + str(config.out_dim) + \
"\t网络层数: " + str(config.n_layers)
if model_name == "GAT":
model_info += "\tn_heads: " + str(config.n_heads)
model_info += "\nin_feat_dropout: " + str(config.in_feat_dropout) + \
"\tdropout: " + str(config.dropout)
model_info += "\nbatch_norm: " + str(config.batch_norm) + \
"\tresidual: " + str(config.residual)
model_info += "\n"
model_info += "-" * 42
return model_info
if __name__ == '__main__':
model_name = 'GatedGCN'
config = Config(model_name=model_name, device='GPU')
model_info = get_model_info(model_name, config)
print(model_info)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。