代码拉取完成,页面将自动刷新
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
get word2vec embeddings by running trian.py.
python train.py --device_target=[DEVICE_TARGET]
"""
import argparse
import ast
import os
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common import set_seed
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from mindspore.train.callback import LossMonitor, TimeMonitor
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.model import Model
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from src.config import w2v_cfg
from src.dataset import DataController
from src.lr_scheduler import poly_decay_lr
from src.skipgram import SkipGram
parser = argparse.ArgumentParser(description='Train SkipGram')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='device target, support Ascend and GPU.')
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend.')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='run distribute.')
parser.add_argument('--pre_trained', type=str, default=None, help='the pretrained checkpoint file path.')
parser.add_argument('--train_data_dir', type=str, default=None, help='the directory of train data.')
args = parser.parse_args()
set_seed(1)
if __name__ == '__main__':
if not os.path.exists(w2v_cfg.temp_dir):
os.mkdir(w2v_cfg.temp_dir)
if not os.path.exists(w2v_cfg.ckpt_dir):
os.mkdir(w2v_cfg.ckpt_dir)
print("Set Context...")
rank_size = int(os.getenv('RANK_SIZE')) if args.run_distribute else 1
rank_id = int(os.getenv('RANK_ID')) if args.run_distribute else 0
context.set_context(mode=context.GRAPH_MODE,
device_target=args.device_target,
device_id=args.device_id,
save_graphs=False)
if args.run_distribute:
init()
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=rank_size,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
print('Done.')
print("Get Mindrecord...")
train_data_dir = w2v_cfg.train_data_dir
if args.train_data_dir:
train_data_dir = args.train_data_dir
data_controller = DataController(train_data_dir, w2v_cfg.ms_dir, w2v_cfg.min_count, w2v_cfg.window_size,
w2v_cfg.neg_sample_num, w2v_cfg.data_epoch, w2v_cfg.batch_size,
rank_size, rank_id)
dataset = data_controller.get_mindrecord_dataset(col_list=['c_words', 'p_words', 'n_words'])
print('Done.')
print("Configure Training Parameters...")
config_ck = CheckpointConfig(save_checkpoint_steps=w2v_cfg.save_checkpoint_steps,
keep_checkpoint_max=w2v_cfg.keep_checkpoint_max)
ckpoint = ModelCheckpoint(prefix="w2v", directory=w2v_cfg.ckpt_dir, config=config_ck)
loss_monitor = LossMonitor(w2v_cfg.print_interval)
time_monitor = TimeMonitor()
total_step = dataset.get_dataset_size() * w2v_cfg.train_epoch
print('Total Step:', total_step)
decay_step = min(total_step, int(2.4e6) // rank_size)
lrs = Tensor(poly_decay_lr(w2v_cfg.lr, w2v_cfg.end_lr, decay_step, total_step, w2v_cfg.power,
update_decay_step=False))
callbacks = [loss_monitor, time_monitor]
if rank_id == 0:
callbacks = [loss_monitor, time_monitor, ckpoint]
net = SkipGram(data_controller.get_vocabs_size(), w2v_cfg.emb_size)
if args.pre_trained:
load_param_into_net(net, load_checkpoint(args.pre_trained))
optim = nn.Adam(net.trainable_params(), learning_rate=lrs)
train_net = nn.TrainOneStepCell(network=net, optimizer=optim)
model = Model(train_net)
print('Done.')
print("Train Model...")
if w2v_cfg.dataset_sink_mode:
epoch_num = int(w2v_cfg.train_epoch * dataset.get_dataset_size() / w2v_cfg.print_interval)
else:
epoch_num = w2v_cfg.train_epoch
model.train(epoch=epoch_num, train_dataset=dataset,
callbacks=callbacks, dataset_sink_mode=w2v_cfg.dataset_sink_mode, sink_size=w2v_cfg.print_interval)
print('Done.')
print("Save Word2Vec Embedding...")
net.save_w2v_emb(w2v_cfg.w2v_emb_save_dir, data_controller.id2word) # save word2vec embedding
print('Done.')
print("End.")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。