From 2e1bdea620c90cdb49127695c33cdee033f2d370 Mon Sep 17 00:00:00 2001 From: chenyuxing <2818499974@qq.com> Date: Thu, 7 Aug 2025 10:15:07 +0800 Subject: [PATCH] check whether num_layers equals sum of num_layer_list --- .../msprof_analyze/tinker/utils/config.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/profiler/msprof_analyze/tinker/utils/config.py b/profiler/msprof_analyze/tinker/utils/config.py index 3d958c4dd..0c7bb0c4c 100644 --- a/profiler/msprof_analyze/tinker/utils/config.py +++ b/profiler/msprof_analyze/tinker/utils/config.py @@ -17,6 +17,7 @@ import os import argparse from functools import partial +import json from tinker.utils.utils import extract_arg_value_from_json, check_path_exist, \ check_path_before_create, check_files_in_dir, check_file_suffix, project_root @@ -178,6 +179,27 @@ def process_path(args): args.profiled_data_path = os.path.join(project_dir, 'profiled_data', args.profiled_data_path) +def check_layers(args): + if args.mode == 'simulate': + data_path = args.profiled_data_path + for filename in os.listdir(data_path): + # 检查文件名称是否是以“model_info”开头的json文件 + if filename.startswith("model_info") and filename.endswith("json"): + # 构建完整的文件路径 + file_path = os.path.join(data_path, filename) + + # 读取文件内容 + with open(file_path, "r", encoding="utf-8") as f: + data = json.load(f) + # 提取num_layers的值 + num_layers = data.get("num_layers") + parts = args.num_layer_list.split(',') + # 把每个部分转成整数 + int_list = [int(parts) for part in parts] + if num_layers != sum(int_list): + raise ValueError("sum of num_layer_list should be equal to num_layers") + + def check_args(args: argparse.Namespace) -> argparse.Namespace: """参数校验""" @@ -219,6 +241,7 @@ def check_args(args: argparse.Namespace) -> argparse.Namespace: check_args_none(args) process_path(args) + check_layers(args) check_path_valid(args.mode) check_post_train(args) -- Gitee