diff --git a/official/nlp/Pangu_alpha/src/adam.py b/official/nlp/Pangu_alpha/src/adam.py index 4488080e0cbdfea3acc0128c655cf81fa243da23..5f06c66bd88b3cb224478c4b8ef41e04e3540889 100644 --- a/official/nlp/Pangu_alpha/src/adam.py +++ b/official/nlp/Pangu_alpha/src/adam.py @@ -39,13 +39,15 @@ def _update_run_kernel(opt, clip_value, beta1, beta2, eps, lr, weight_decay, Update parameters by AdamWeightDecay op. """ success = True + cast = P.Cast() + cast.add_prim_attr("primitive_target", "CPU") if optim_filter: if decay_flags: next_param = opt(param, m, v, lr, beta1, beta2, eps, weight_decay, - P.Cast()(gradient, mstype.float16), clip_value) + cast(gradient, mstype.float16), clip_value) else: next_param = opt(param, m, v, lr, beta1, beta2, eps, 0.0, - P.Cast()(gradient, mstype.float16), clip_value) + cast(gradient, mstype.float16), clip_value) return F.depend(success, next_param) return success