diff --git a/mindspeed_ms/tools/converter/combine_ckpt_dp_zero.py b/mindspeed_ms/tools/converter/combine_ckpt_dp_zero.py index 4912239950384ed859071adc1052e4cbec0b14c5..bcd6fd5da9f22d3a2260dc0b729a70179280b7af 100644 --- a/mindspeed_ms/tools/converter/combine_ckpt_dp_zero.py +++ b/mindspeed_ms/tools/converter/combine_ckpt_dp_zero.py @@ -77,7 +77,7 @@ def combine_zero3_data(param_total_dict, param_name, no_save_optim): return param_data def check_key(key): - key_list = ['learning_rate', 'weight_decay', 'epoch', 'state', 'default_generator'] + key_list = ['learning_rate', 'weight_decay', 'epoch', 'step', 'default_generator'] if any(x in key for x in key_list): return True return False