# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
get word2vec embeddings by running train.py.
python train.py --device_target=[DEVICE_TARGET]
"""
import argparse
import ast
import os
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common import set_seed
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from mindspore.train.callback import LossMonitor, TimeMonitor
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.model import Model
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from src.config import w2v_cfg
from src.dataset import DataController
from src.lr_scheduler import poly_decay_lr
from src.skipgram import SkipGram
parser = argparse.ArgumentParser(description='Train SkipGram')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                    help='device target, support Ascend and GPU.')
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend.')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='run distribute.')
parser.add_argument('--pre_trained', type=str, default=None, help='the pretrained checkpoint file path.')
parser.add_argument('--train_data_dir', type=str, default=None, help='the directory of train data.')
args = parser.parse_args()
set_seed(1)
if __name__ == '__main__':
    if not os.path.exists(w2v_cfg.temp_dir):
        os.mkdir(w2v_cfg.temp_dir)
    if not os.path.exists(w2v_cfg.ckpt_dir):
        os.mkdir(w2v_cfg.ckpt_dir)
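    # Device / parallel context. RANK_SIZE and RANK_ID are expected to be provided by
    # the distributed launcher when --run_distribute=True; a single-device run falls
    # back to rank_size=1, rank_id=0. In data-parallel mode gradients are averaged
    # across devices (gradients_mean=True).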
    print("Set Context...")
    rank_size = int(os.getenv('RANK_SIZE')) if args.run_distribute else 1
    rank_id = int(os.getenv('RANK_ID')) if args.run_distribute else 0
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args.device_target,
                        device_id=args.device_id,
                        save_graphs=False)
    if args.run_distribute:
        init()
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=rank_size,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
    print('Done.')
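    # Build (or reuse) the MindRecord dataset. The three columns are presumably the
    # center words, positive context words and negative samples consumed by the
    # skip-gram negative-sampling objective; rank_size/rank_id are passed so each
    # device can read its own shard of the data.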
    print("Get Mindrecord...")
    train_data_dir = w2v_cfg.train_data_dir
    if args.train_data_dir:
        train_data_dir = args.train_data_dir
    data_controller = DataController(train_data_dir, w2v_cfg.ms_dir, w2v_cfg.min_count, w2v_cfg.window_size,
                                     w2v_cfg.neg_sample_num, w2v_cfg.data_epoch, w2v_cfg.batch_size,
                                     rank_size, rank_id)
    dataset = data_controller.get_mindrecord_dataset(col_list=['c_words', 'p_words', 'n_words'])
    print('Done.')
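    # Checkpointing, logging and learning-rate schedule. The learning rate follows a
    # polynomial decay from w2v_cfg.lr to w2v_cfg.end_lr over decay_step steps, where
    # decay_step is the smaller of total_step and 2.4e6 // rank_size. Only rank 0
    # registers the checkpoint callback so files are written once per job.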
    print("Configure Training Parameters...")
    config_ck = CheckpointConfig(save_checkpoint_steps=w2v_cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=w2v_cfg.keep_checkpoint_max)
    ckpoint = ModelCheckpoint(prefix="w2v", directory=w2v_cfg.ckpt_dir, config=config_ck)
    loss_monitor = LossMonitor(w2v_cfg.print_interval)
    time_monitor = TimeMonitor()
    total_step = dataset.get_dataset_size() * w2v_cfg.train_epoch
    print('Total Step:', total_step)
    decay_step = min(total_step, int(2.4e6) // rank_size)
    lrs = Tensor(poly_decay_lr(w2v_cfg.lr, w2v_cfg.end_lr, decay_step, total_step, w2v_cfg.power,
                               update_decay_step=False))
    callbacks = [loss_monitor, time_monitor]
    if rank_id == 0:
        callbacks = [loss_monitor, time_monitor, ckpoint]
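    # Network and optimizer. The SkipGram cell appears to return the training loss
    # directly, so TrainOneStepCell can bundle it with Adam and run forward, backward
    # and update as a single step inside Model.train. A pretrained checkpoint can be
    # loaded to resume training.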
    net = SkipGram(data_controller.get_vocabs_size(), w2v_cfg.emb_size)
    if args.pre_trained:
        load_param_into_net(net, load_checkpoint(args.pre_trained))
    optim = nn.Adam(net.trainable_params(), learning_rate=lrs)
    train_net = nn.TrainOneStepCell(network=net, optimizer=optim)
    model = Model(train_net)
    print('Done.')
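    # With dataset sinking enabled, Model.train treats every sink_size steps as one
    # "epoch", so the epoch count is rescaled from w2v_cfg.train_epoch (full dataset
    # passes) to the equivalent number of sink_size-step chunks; otherwise the epoch
    # count is used as-is.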
    print("Train Model...")
    if w2v_cfg.dataset_sink_mode:
        epoch_num = int(w2v_cfg.train_epoch * dataset.get_dataset_size() / w2v_cfg.print_interval)
    else:
        epoch_num = w2v_cfg.train_epoch
    model.train(epoch=epoch_num, train_dataset=dataset,
                callbacks=callbacks, dataset_sink_mode=w2v_cfg.dataset_sink_mode, sink_size=w2v_cfg.print_interval)
    print('Done.')
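    # Persist the learned embedding table, keyed by the id-to-word vocabulary mapping.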
    print("Save Word2Vec Embedding...")
    net.save_w2v_emb(w2v_cfg.w2v_emb_save_dir, data_controller.id2word)  # save word2vec embedding
    print('Done.')
    print("End.")