While keeping the original project structure unchanged, an open-source project can be migrated along the path GPU >> NPU >> Rec SDK by replacing the relevant API calls. Because compatibility issues during migration and adaptation may cause the migration to fail, an alternative adaptation approach is provided here.
wangzhen38 committed on May 19, 2022, commit ID: 54b6a96abae574a04c7fcac53df190f885970c3f
Commit link: https://github.com/PaddlePaddle/PaddleRec/commit/54b6a96abae574a04c7fcac53df190f885970c3f
Model directory: https://github.com/PaddlePaddle/PaddleRec/tree/master/models/multitask/mmoe
Census-Income (KDD) dataset:
https://archive.ics.uci.edu/static/public/117/census+income+kdd.zip
Data preprocessing script: census.py. Run it as follows:
python census.py --train_data_path train_data_path --test_data_path test_data_path --output_path output_path
Parameter description: --train_data_path is the path to the raw training data, --test_data_path the path to the raw test data, and --output_path the directory for the processed output.
Call the get_fea_map() method in census.py to build the deduplicated feature mapping for the sparse features, stored in the form {'C1': {}, 'C2': {}, ..., 'I1': {}, ...}.
# get feature_map
feature_map = get_fea_map(split_file_list=list(file_path_dict.values()))
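For reference, the following is a minimal sketch of how such a mapping can be built, assuming pandas-readable input files and an explicit sparse_features column list; the actual get_fea_map() in census.py may differ in parsing details.
import pandas as pd

# Hypothetical sketch only; the real get_fea_map() may parse files differently.
def build_feature_map(split_file_list, sparse_features):
    feature_map = {col: {} for col in sparse_features}  # {'C1': {}, 'C2': {}, ...}
    for file_path in split_file_list:
        data_df = pd.read_csv(file_path)
        for col in sparse_features:
            for value in data_df[col].unique():         # deduplicate raw strings
                if value not in feature_map[col]:
                    # assign the next free int64 ID, so IDs run from 0 upward
                    feature_map[col][value] = len(feature_map[col])
    return feature_map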
The following operation then maps the raw string values to int64 IDs ranging from 0 to the maximum index:
# sparse feature: mapping
for col in sparse_features:
    try:
        data_df[col] = data_df[col].map(lambda x: feature_map[col][x])
    except KeyError as e:
        raise KeyError("Feature {} not found in dataset".format(col)) from e
Call the convert_input2tfrd() method in census.py to convert the txt files into TFRecord files.
# txt to tfrecords
convert_input2tfrd(data_frame=data_df, in_file_path=file_path, out_file_path=output_path_)
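As a reference for what the conversion produces, here is a minimal sketch of writing a DataFrame to a TFRecord file. The feature keys and dtypes mirror the extract_fn() schema shown in section 3.1; the column-list parameters and the name write_tfrecord are illustrative, and the real convert_input2tfrd() may differ.
import tensorflow as tf

def write_tfrecord(data_frame, out_file_path, label_cols, sparse_cols, dense_cols):
    """Hypothetical sketch: serialize each row as a tf.train.Example."""
    with tf.io.TFRecordWriter(out_file_path) as writer:
        for _, row in data_frame.iterrows():
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[int(row[c]) for c in label_cols])),
                'sparse_feature': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[int(row[c]) for c in sparse_cols])),
                'dense_feature': tf.train.Feature(
                    float_list=tf.train.FloatList(value=[float(row[c]) for c in dense_cols])),
            }))
            writer.write(example.SerializeToString())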
After configuring the environment on the NPU server, installing the image, and creating a container as described in the Rec SDK README.md, model training can be started following the DLRM run commands. The launch script is run.sh, which takes four arguments: so_path, rec_package_path, hccl_cfg_json, and dlrm_criteo_data_path.
The Rec SDK can be launched in two ways: with an HCCL configuration file (the rank-table approach) or without one (the rank-table-free approach).
# rank-table approach
bash run.sh {so_path} {rec_package_path} {hccl_cfg_json} {dlrm_criteo_data_path}
# rank-table-free approach: an additional IP argument is required
bash run.sh {so_path} {rec_package_path} {hccl_cfg_json} {dlrm_criteo_data_path} {IP}
For example: bash run.sh /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec/libasc/ /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec/ hccl_json_8p.json /dataset 10.10.10.10
Note: with the rank-table-free approach, the model still runs normally even though no HCCL file exists in the current directory.
Migration approach: starting from the already-adapted DLRM model framework, modify the relevant code logic to adapt the MMoE model. The core work: rewrite model.py based on the open-source project's model code; move part of the data processing into census.py and part into make_batch_and_iterator() in main_mxrec.py; the remaining code changes in main_mxrec.py mainly adapt the model to the features provided by the Rec SDK.
For the detailed changes, see https://gitee.com/ascend/RecSDK/pulls/231/commits, commit ID: e769a53cff0a7241aea5959434cee8518a22a7be.
The dynamic expansion, dynamic shape, automatic graph modification, and one-table-multiple-lookup features mentioned below are provided by the Rec SDK; their switches are in run.sh:
# run.sh, lines 32~37
export USE_DYNAMIC=0            # 0: static shape; 1: dynamic shape
export CACHE_MODE="HBM"         # HBM; DDR; SSD
export USE_FAAE=0               # 0: disable feature admission/eviction; 1: enable
export USE_DYNAMIC_EXPANSION=0  # 0: disable dynamic expansion; 1: enable
export USE_MULTI_LOOKUP=0       # 0: one lookup per table; 1: multiple lookups per table
export USE_MODIFY_GRAPH=0       # 0: feature spec mode; 1: automatic graph modification mode
Migration note: gradient_descent_w.py and mean_auc.py are not used during migration.
The experiment hyperparameters are configured as follows. The dynamic learning-rate logic is removed and the learning rate is fixed at 0.001:
# lines 88~89
lr_sparse = self.base_lr_sparse * lr_factor_constant
lr_dense = self.base_lr_dense * lr_factor_constant

# lines 140~146
_lr_scheduler = LearningRateScheduler(
    0.001,
    0.001
)
# hyperparameters
self.batch_size = 32
self.line_per_sample = 1
self.train_epoch = 100
self.test_epoch = 100
self.expert_num = 8    # number of expert networks
self.gate_num = 2      # number of gate networks
self.expert_size = 16
self.tower_size = 8
# embedding dimension: expert_num * expert_size + gate_num * expert_num
# = 8 * 16 + 2 * 8 = 144; build_model() slices this into per-expert and per-gate chunks
self.emb_dim = self.expert_num * self.expert_size + self.gate_num * self.expert_num
During migration, model.py must be rewritten with low-level TensorFlow APIs, following the code logic of models/multitask/mmoe/net.py in the open-source project. The output parameters must include loss, prediction, label, and trainable_variables. Key migration point: the Rec SDK accelerates the table-creation and table-lookup operations for sparse features in recommendation models; the create_table and sparse_lookup interfaces replace TensorFlow's tf.nn.embedding_lookup interface. As a result, when adapting the open-source project, the embedding of sparse features is moved outside the model structure.
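To make the replacement concrete, the snippet below shows the plain TensorFlow lookup that the Rec SDK supersedes; the sizes are illustrative (29 sparse fields as in this model, emb_dim = 144 from the hyperparameters above), not taken from the migrated code. With the SDK, create_table builds the accelerated table and sparse_lookup performs the lookup (see model_forward() in section 3.2 for the actual call), so this table variable no longer lives inside the model graph.
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

VOCAB_SIZE, EMB_DIM = 1000, 144  # illustrative sizes only
# the GPU-style embedding that create_table/sparse_lookup replace on the NPU
table = tf.compat.v1.get_variable("emb_table", shape=[VOCAB_SIZE, EMB_DIM])
ids = tf.compat.v1.placeholder(tf.int64, shape=[None, 29], name="sparse_feature")
emb = tf.nn.embedding_lookup(table, ids)   # [batch, 29, EMB_DIM]
emb = tf.reduce_sum(emb, axis=1)           # [batch, EMB_DIM], as in model_forward()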
Original code of the PaddleRec open-source project:
# net.py
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class MMoELayer(nn.Layer):
    def __init__(self, feature_size, expert_num, expert_size, tower_size,
                 gate_num):
        super(MMoELayer, self).__init__()

        self.expert_num = expert_num
        self.expert_size = expert_size
        self.tower_size = tower_size
        self.gate_num = gate_num

        self._param_expert = []
        expert_init = [pow(10, -i) for i in range(1, self.expert_num + 1)]
        for i in range(0, self.expert_num):
            linear = self.add_sublayer(
                name='expert_' + str(i),
                sublayer=nn.Linear(
                    feature_size,
                    expert_size,
                    # initialize each expert respectively
                    weight_attr=nn.initializer.Constant(value=expert_init[i]),
                    bias_attr=nn.initializer.Constant(value=0.1),
                    # bias_attr=paddle.ParamAttr(learning_rate=1.0),
                    name='expert_' + str(i)))
            self._param_expert.append(linear)

        self._param_gate = []
        self._param_tower = []
        self._param_tower_out = []
        gate_init = [pow(10, -i) for i in range(1, self.gate_num + 1)]
        for i in range(0, self.gate_num):
            linear = self.add_sublayer(
                name='gate_' + str(i),
                sublayer=nn.Linear(
                    feature_size,
                    expert_num,
                    # initialize each gate respectively
                    weight_attr=nn.initializer.Constant(value=gate_init[i]),
                    bias_attr=nn.initializer.Constant(value=0.1),
                    # bias_attr=paddle.ParamAttr(learning_rate=1.0),
                    name='gate_' + str(i)))
            self._param_gate.append(linear)

            linear = self.add_sublayer(
                name='tower_' + str(i),
                sublayer=nn.Linear(
                    expert_size,
                    tower_size,
                    # initialize each tower respectively
                    weight_attr=nn.initializer.Constant(value=gate_init[i]),
                    bias_attr=nn.initializer.Constant(value=0.1),
                    # bias_attr=paddle.ParamAttr(learning_rate=1.0),
                    name='tower_' + str(i)))
            self._param_tower.append(linear)

            linear = self.add_sublayer(
                name='tower_out_' + str(i),
                sublayer=nn.Linear(
                    tower_size,
                    2,
                    # initialize each tower output respectively
                    weight_attr=nn.initializer.Constant(value=gate_init[i]),
                    bias_attr=nn.initializer.Constant(value=0.1),
                    name='tower_out_' + str(i)))
            self._param_tower_out.append(linear)

    def forward(self, input_data):
        expert_outputs = []
        for i in range(0, self.expert_num):
            linear_out = self._param_expert[i](input_data)
            expert_output = F.relu(linear_out)
            expert_outputs.append(expert_output)
        expert_concat = paddle.concat(x=expert_outputs, axis=1)
        expert_concat = paddle.reshape(
            expert_concat, [-1, self.expert_num, self.expert_size])

        output_layers = []
        for i in range(0, self.gate_num):
            cur_gate_linear = self._param_gate[i](input_data)
            cur_gate = F.softmax(cur_gate_linear)
            cur_gate = paddle.reshape(cur_gate, [-1, self.expert_num, 1])
            cur_gate_expert = paddle.multiply(x=expert_concat, y=cur_gate)
            cur_gate_expert = paddle.sum(x=cur_gate_expert, axis=1)
            cur_tower = self._param_tower[i](cur_gate_expert)
            cur_tower = F.relu(cur_tower)
            out = self._param_tower_out[i](cur_tower)
            out = F.softmax(out)
            out = paddle.clip(out, min=1e-15, max=1.0 - 1e-15)
            output_layers.append(out)
        return output_layers
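For orientation, an illustrative forward call of the layer above; feature_size and the batch size are stand-in values, not taken from the PaddleRec config.
import paddle

# illustrative sizes; the real feature_size comes from the PaddleRec config
layer = MMoELayer(feature_size=128, expert_num=8, expert_size=16,
                  tower_size=8, gate_num=2)
x = paddle.rand([32, 128])                 # one batch of mixed input features
out_income, out_marital = layer(x)         # gate_num=2 task heads, each [32, 2]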
In the original code, _input does not separate sparse and dense features. After migration, the sparse data is handled separately (see main_mxrec.py): the sparse lookup embeddings and the dense data are passed into the model as separate inputs, then sliced and aggregated inside the model to reproduce the original computation.
Migrated code:
# model.py
from easydict import EasyDict as edict
import tensorflow as tf

model_cfg = edict()
model_cfg.loss_mode = "batch"

LOSS_OP_NAME = "loss"
LABEL_OP_NAME = "label"
VAR_LIST = "variable"
PRED_OP_NAME = "pred"


class MyModel:
    def __init__(self, expert_num=8, expert_size=16, tower_size=8, gate_num=2):
        self.expert_num = expert_num
        self.expert_size = expert_size
        self.tower_size = tower_size
        self.gate_num = gate_num

    def expert_layer(self, _input):
        param_expert = []
        for i in range(0, self.expert_num):
            expert_linear = tf.layers.dense(_input, units=self.expert_size, activation=None, name=f'expert_layer_{i}',
                                            kernel_initializer=tf.constant_initializer(value=0.1),
                                            bias_initializer=tf.constant_initializer(value=0.1))
            param_expert.append(expert_linear)
        return param_expert

    def gate_layer(self, _input):
        param_gate = []
        for i in range(0, self.gate_num):
            gate_linear = tf.layers.dense(_input, units=self.expert_num, activation=None, name=f'gate_layer_{i}',
                                          kernel_initializer=tf.constant_initializer(value=0.1),
                                          bias_initializer=tf.constant_initializer(value=0.1))
            param_gate.append(gate_linear)
        return param_gate

    def tower_layer(self, _input, layer_name):
        tower_linear = tf.layers.dense(_input, units=self.tower_size, activation='relu',
                                       name=f'tower_layer_{layer_name}',
                                       kernel_initializer=tf.constant_initializer(value=0.1),
                                       bias_initializer=tf.constant_initializer(value=0.1))
        tower_linear_out = tf.layers.dense(tower_linear, units=2, activation=None,
                                           name=f'tower_payer_out_{layer_name}',
                                           kernel_initializer=tf.constant_initializer(value=0.1),
                                           bias_initializer=tf.constant_initializer(value=0.1))
        return tower_linear_out

    def build_model(self,
                    embedding=None,
                    dense_feature=None,
                    label=None,
                    is_training=True,
                    seed=None):
        with tf.variable_scope("mmoe", reuse=tf.AUTO_REUSE):
            dense_expert = self.expert_layer(dense_feature)
            dense_gate = self.gate_layer(dense_feature)
            # slice the sparse embedding: one expert_size-wide chunk per expert
            all_expert = []
            _slice_num = 0
            for i in range(0, self.expert_num):
                slice_num_end = _slice_num + self.expert_size
                cur_expert = tf.add(dense_expert[i], embedding[:, _slice_num:slice_num_end])
                cur_expert = tf.nn.relu(cur_expert)
                all_expert.append(cur_expert)
                _slice_num = slice_num_end
            expert_concat = tf.concat(all_expert, axis=1)
            expert_concat = tf.reshape(expert_concat, [-1, self.expert_num, self.expert_size])

            output_layers = []
            out_pred = []
            for i in range(0, self.gate_num):
                # the remaining expert_num-wide chunks feed the gates
                slice_gate_end = _slice_num + self.expert_num
                cur_gate = tf.add(dense_gate[i], embedding[:, _slice_num:slice_gate_end])
                cur_gate = tf.nn.softmax(cur_gate)
                cur_gate = tf.reshape(cur_gate, [-1, self.expert_num, 1])
                cur_gate_expert = tf.multiply(x=expert_concat, y=cur_gate)
                cur_gate_expert = tf.reduce_sum(cur_gate_expert, axis=1)
                out = self.tower_layer(cur_gate_expert, i)
                out = tf.nn.softmax(out)
                out = tf.clip_by_value(out, clip_value_min=1e-15, clip_value_max=1.0 - 1e-15)
                output_layers.append(out)
                out_pred.append(tf.nn.softmax(out[:, 1]))
                _slice_num = slice_gate_end

        trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='mmoe')

        label_income = label[:, 0:1]
        label_mat = label[:, 1:]
        pred_income_1 = tf.slice(output_layers[0], [0, 1], [-1, 1])
        pred_marital_1 = tf.slice(output_layers[1], [0, 1], [-1, 1])
        cost_income = tf.losses.log_loss(labels=tf.cast(label_income, tf.float32), predictions=pred_income_1,
                                         epsilon=1e-4)
        cost_marital = tf.losses.log_loss(labels=tf.cast(label_mat, tf.float32), predictions=pred_marital_1,
                                          epsilon=1e-4)
        avg_cost_income = tf.reduce_mean(cost_income)
        avg_cost_marital = tf.reduce_mean(cost_marital)
        loss = 0.5 * (avg_cost_income + avg_cost_marital)

        return {LOSS_OP_NAME: loss,
                PRED_OP_NAME: out_pred,
                LABEL_OP_NAME: label,
                VAR_LIST: trainable_variables}
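A minimal standalone check of the migrated graph, assuming the TensorFlow 1.x environment implied by tf.layers/tf.variable_scope and the hyperparameters above (emb_dim = 144, 11 dense features, 2 labels); this smoke test is illustrative and not part of the migrated code.
# hypothetical smoke test for MyModel
embedding = tf.placeholder(tf.float32, [None, 144], name="embedding")
dense_feature = tf.placeholder(tf.float32, [None, 11], name="dense_feature")
label = tf.placeholder(tf.int64, [None, 2], name="label")

my_model = MyModel(expert_num=8, expert_size=16, tower_size=8, gate_num=2)
outputs = my_model.build_model(embedding=embedding, dense_feature=dense_feature, label=label)
# outputs[LOSS_OP_NAME] is the scalar loss; outputs[PRED_OP_NAME] holds one
# prediction tensor per task (income, marital)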
The relevant functions in main_mxrec.py are listed below. make_batch_and_iterator() reads the dataset and preprocesses the data; model_forward() implements the forward pass; evaluate() and evaluate_fix() compute AUC and loss on the test set; add_timestamp_func() relates to feature admission and eviction; create_feature_spec_list() builds the list of FeatureSpec instances that make_batch_and_iterator() takes as input. For explanations of feature admission/eviction, the FeatureSpec class, automatic graph modification, and so on, see the Rec SDK user guide.
add_timestamp_func()
make_batch_and_iterator()
model_forward()
evaluate()
evaluate_fix()
create_feature_spec_list()
Migration change notes: add_timestamp_func(), evaluate(), and evaluate_fix() are unchanged.
3.1 Reading the dataset: make_batch_and_iterator()
# main_mxrec.py, lines 65~79
def extract_fn(data_record):
    features = {
        # Extract features using the keys set during creation
        'label': tf.compat.v1.FixedLenFeature(shape=(2 * config.line_per_sample,), dtype=tf.int64),
        'sparse_feature': tf.compat.v1.FixedLenFeature(shape=(29 * config.line_per_sample,), dtype=tf.int64),
        'dense_feature': tf.compat.v1.FixedLenFeature(shape=(11 * config.line_per_sample,), dtype=tf.float32),
    }
    sample = tf.compat.v1.parse_single_example(data_record, features)
    return sample


def reshape_fn(batch):
    batch['label'] = tf.reshape(batch['label'], [-1, 2])
    batch['dense_feature'] = tf.reshape(batch['dense_feature'], [-1, 11])
    batch['sparse_feature'] = tf.reshape(batch['sparse_feature'], [-1, 29])
    return batch
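A minimal sketch of how extract_fn() and reshape_fn() plug into a tf.data pipeline; train_file_list is a placeholder, and the real make_batch_and_iterator() additionally builds the iterator from the FeatureSpec list required by the Rec SDK, so the details differ.
# hypothetical wiring; the actual function is driven by the FeatureSpec list
dataset = tf.data.TFRecordDataset(train_file_list)           # TFRecords from census.py
dataset = dataset.map(extract_fn)                            # parse one Example per record
dataset = dataset.batch(config.batch_size, drop_remainder=True)
dataset = dataset.map(reshape_fn)                            # -> [-1, 2] / [-1, 11] / [-1, 29]
iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
batch = iterator.get_next()                                  # the `batch` passed to model_forward()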
3.2 Model forward pass
# main_mxrec.py, lines 112~137
def model_forward(feature_list, hash_table_list, batch, is_train, modify_graph):
    embedding_list = []
    logger.debug(f"In model_forward function, is_train: {is_train}, feature_list: {len(feature_list)}, "
                 f"hash_table_list: {len(hash_table_list)}")
    for feature, hash_table in zip(feature_list, hash_table_list):
        if MODIFY_GRAPH_FLAG:
            feature = batch["sparse_feature"]
        embedding = sparse_lookup(hash_table, feature, cfg.send_count, dim=None, is_train=is_train,
                                  name="user_embedding_lookup", modify_graph=modify_graph, batch=batch,
                                  access_and_evict_config=None)
        embedding_list.append(embedding)
    if len(embedding_list) == 1:
        emb = embedding_list[0]
    elif len(embedding_list) > 1:
        emb = tf.reduce_sum(embedding_list, axis=0, keepdims=False)
    else:
        raise ValueError("the length of embedding_list must be greater than or equal to 1.")
    emb = tf.reduce_sum(emb, axis=1)
    my_model = MyModel()
    model_output = my_model.build_model(embedding=emb,
                                        dense_feature=batch["dense_feature"],
                                        label=batch["label"],
                                        is_training=is_train,
                                        seed=dense_hashtable_seed)
    return model_output
This is the forward-pass function. It consists mainly of the sparse-feature embedding (table lookup) and the model's forward computation: sparse_lookup() returns the embeddings of the 29 sparse fields, which are summed over the field axis (tf.reduce_sum(emb, axis=1)) into a single [batch_size, emb_dim] tensor before being passed to build_model() together with the dense features and labels.