diff --git a/msit/components/convert/model_convert/aie/cpp/aie_convert.cpp b/msit/components/convert/model_convert/aie/cpp/aie_convert.cpp index d762384cd7f6ba6f1f2fbb749e419c35194f41ca..71bf36192542c46624e8e9d152d0dc3fbc5bd1b6 100644 --- a/msit/components/convert/model_convert/aie/cpp/aie_convert.cpp +++ b/msit/components/convert/model_convert/aie/cpp/aie_convert.cpp @@ -49,10 +49,27 @@ int main(int argc, char** argv) auto modelData = builder->BuildModel(network, config); std::ofstream fout(outputPath, std::ios::binary); + // 检查打开文件是否成功 + if (!fout) { + throw std::runtime_error("Failed to open: " + outputPath); + } fout.write((char*)modelData.data.get(), modelData.size); + // 检查写入是否成功,否则清除残留文件 + if (!fout) { + fout.close(); + std::error_code ignore; + std::filesystem::remove(outputPath, ignore); + throw std::runtime_error("Failed to write to: " + outputPath); + } fout.close(); + // 检查关闭文件是否成功 + if (!fout) { + std::error_code ignore; + std::filesystem::remove(outputPath, ignore); + throw std::runtime_error("Failed to close file: " + outputPath); + } std::cout << "AIE Model Convert Succeed" << std::endl; diff --git a/msit/components/debug/compare/msquickcmp/save_om_model/export_om_model.cpp b/msit/components/debug/compare/msquickcmp/save_om_model/export_om_model.cpp index 292dde830c10d3a4779c371bc467e09daba4322b..2021416bce753908aa7f3151dd4aaa25618cd795 100644 --- a/msit/components/debug/compare/msquickcmp/save_om_model/export_om_model.cpp +++ b/msit/components/debug/compare/msquickcmp/save_om_model/export_om_model.cpp @@ -76,7 +76,16 @@ static std::string GetFullPath(const std::string &originPath) return originPath; } - std::string cwd = getcwd(nullptr, 0); + // 获取当前工作目录,检查是否成功 + char* cwd_ptr = getcwd(nullptr, 0); + if (!cwd_ptr) { // 检查getcwd返回值是否为nullptr + // 抛出异常 + throw std::runtime_error("Failed to get current working directory"); + } + + // 安全处理动态分配的内存 + std::string cwd(cwd_ptr); + free(cwd_ptr) // 释放getcwd分配的内存,避免泄露 return std::move(cwd + PATH_SEPARATOR + originPath); } diff --git a/msit/components/debug/compare/msquickcmp/tf/tf_dump_data.py b/msit/components/debug/compare/msquickcmp/tf/tf_dump_data.py index ce9efd022eb2039ff712ad3513369e347a9e5a25..25125d844b948fccb88dbfd613f30ff9ffb43ee6 100644 --- a/msit/components/debug/compare/msquickcmp/tf/tf_dump_data.py +++ b/msit/components/debug/compare/msquickcmp/tf/tf_dump_data.py @@ -31,7 +31,7 @@ from msquickcmp.common.utils import AccuracyCompareException from components.utils.file_open_check import ms_open from components.utils.constants import TENSOR_MAX_SIZE -from components.utils.util import load_file_to_read_common_check +from components.utils.util import load_file_to_read_common_check, check_str_for_cmd class TfDumpData(DumpData): @@ -193,6 +193,7 @@ class TfDumpData(DumpData): # get the net_output dump file info if tensor_name in self.net_output_name: self.net_output[self.net_output_name.index(tensor_name)] = npy_file_path + check_str_for_cmd(tensor_name, 'tensor_name') pt_command_list.append("pt %s -n %d -w %s" % (tensor_name, count_tensor_name, npy_file_path)) return pt_command_list diff --git a/msit/components/utils/util.py b/msit/components/utils/util.py index 69fe65ff63141064322df6ae2735f25fc9700e21..5153a4ac34490670343590d73bf9175ff2d786ef 100644 --- a/msit/components/utils/util.py +++ b/msit/components/utils/util.py @@ -168,6 +168,16 @@ def filter_cmd(paras): return filtered +def check_str_for_cmd(strings, var_name): + whitelist_pattern = re.compile(r"^[a-zA-Z0-9_\-./=:,\[\] ]+$") + if whitelist_pattern.fullmatch(strings): + pass + else: + raise ValueError( + f'{var_name} contains invalid characters. Only the "{whitelist_pattern}" pattern is allowed.' + ) + + def load_file_to_read_common_check_for_cli(value, exts=None): try: value = load_file_to_read_common_check(value, exts)