1 Star 0 Fork 0

MuJieShan/bert

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 1.61 KB
一键复制 编辑 原始数据 按行查看 历史
MuJieShan 提交于 2023-11-06 20:27 . Initial commit
from strategies import pruner
from utils import *
from dataloader import get_dataloader
def main():
    """Entry point: load a GLUE-style task, fine-tune a pre-trained BERT,
    optionally pruning it first according to ``config.target_ratio``.

    Reads all hyperparameters from ``init_config()`` (seed, dataset name,
    batch sizes, epochs, learning rate, pruning ratio) and delegates the
    actual loop to ``train_eval_loop``.
    """
    config = init_config()
    print(config)
    seed_torch(config.seed)

    # Config settings
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_checkpoint = "bert-base-uncased"
    task = config.dataset
    batch_size = config.batchsize
    steps = config.epoch
    lr = config.learning_rate

    # Load DataLoader
    print("\nLoading data...")
    train_epoch_iterator, eval_epoch_iterator = get_dataloader(task, model_checkpoint, batch_size=batch_size)

    # Load Pre-trained Model
    print(f"\nLoading pre-trained BERT model \"{model_checkpoint}\"")
    model = load_model(model_checkpoint, task, device)

    # Define optimizer and lr_scheduler.
    # NOTE(review): LR_scheduler is created but never stepped or passed to
    # train_eval_loop — confirm whether the scheduler is actually needed.
    optimizer = create_optimizer(model, learning_rate=lr)
    LR_scheduler = create_scheduler(optimizer, num_training_steps=len(train_epoch_iterator) * steps)
    print('train data len:', len(train_epoch_iterator))

    # Bug fix: the original compared config.prune_batchsize on one line but
    # read config.prune_batch_size on the next — one of those attribute
    # lookups would raise AttributeError. Standardized on `prune_batchsize`
    # to match the file's `config.batchsize` spelling.
    # TODO(review): confirm the attribute name against init_config().
    if config.batchsize != config.prune_batchsize:
        # A dedicated train-only loader sized for pruning calibration.
        prune_iterator = iter(get_dataloader(task, model_checkpoint,
                                             batch_size=config.prune_batchsize,
                                             only_train=True))
    else:
        prune_iterator = train_epoch_iterator

    pruning = pruner(model, compression=config.target_ratio, dataset_name=task,
                     data_iterator=prune_iterator, config=config,
                     optimizer=optimizer, device=device)
    # target_ratio == 0.0 means "no pruning": skip straight to training.
    if config.target_ratio != 0.0:
        print('pruning---')
        pruning.apply()
    print('training---')
    train_eval_loop(config, model, train_epoch_iterator, eval_epoch_iterator, optimizer, pruning, device)
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mujieshan/bert.git
git@gitee.com:mujieshan/bert.git
mujieshan
bert
bert
master

搜索帮助