diff --git a/mindspeed_mm/configs/config.py b/mindspeed_mm/configs/config.py
index 6b4f915fe37beced91f05ee51b9b83e8c7522b4a..cb8f30dac198f51bf12588096fb2c1160a9239f4 100644
--- a/mindspeed_mm/configs/config.py
+++ b/mindspeed_mm/configs/config.py
@@ -1,15 +1,17 @@
 import os
 import json
+from functools import wraps
 
 from mindspeed_mm.utils.utils import get_dtype
 
 
 class ConfigReader:
-    """ 
+    """
     read_config
     read json file dict processed by MMconfig and convert to class attributes,
     besides, read_config support to convert dict for specific purposes.
     """
+
     def __init__(self, config_dict: dict) -> None:
         for k, v in config_dict.items():
             if k == "dtype":
@@ -18,7 +20,7 @@ class ConfigReader:
                 self.__dict__[k] = ConfigReader(v)
             else:
                 self.__dict__[k] = v
-    
+
     def to_dict(self) -> dict:
         ret = {}
         for k, v in self.__dict__.items():
@@ -27,7 +29,7 @@ class ConfigReader:
             else:
                 ret[k] = v
         return ret
-    
+
     def __repr__(self) -> str:
         for k, v in self.__dict__.items():
             if isinstance(v, self.__class__):
@@ -58,25 +60,28 @@
 
 
 class MMConfig:
-    """ 
-    MMconfig 
+    """
+    MMconfig
     input: a dict of json path
     """
+
     def __init__(self, json_files: dict) -> None:
         for json_name, json_path in json_files.items():
             if os.path.exists(json_path):
                 real_path = os.path.realpath(json_path)
                 config_dict = self.read_json(real_path)
                 setattr(self, json_name, ConfigReader(config_dict))
-    
+            else:
+                raise Exception("{} doesn't exist".format(json_name))
+
     @staticmethod
     def read_json(json_path):
         with open(json_path, mode="r") as f:
             json_file = f.read()
             config_dict = json.loads(json_file)
         return config_dict
-    
-    
+
+
 def _add_mm_args(parser):
     group = parser.add_argument_group(title="multimodel")
     group.add_argument("--mm-data", type=str, default="")
@@ -90,9 +95,190 @@ def mm_extra_args_provider(parser):
     return parser
 
 
+def merge_mm_args_decorator(func):
+    called = False
+
+    @wraps(func)
+    def wrapper(args):
+        func(args)
+        nonlocal called
+        if not called:
+            args_external_path_checker(args)
+            called = True
+    return wrapper
+
+
+@merge_mm_args_decorator
 def merge_mm_args(args):
     if not hasattr(args, "mm"):
         setattr(args, "mm", object)
     json_files = {"model": args.mm_model, "data": args.mm_data, "tool": args.mm_tool}
     args.mm = MMConfig(json_files)
 
+
+def args_external_path_checker(args):
+    """
+    Verify the security of all file path parameters in 3 code repositories:mindspeed-mm,mindspeed,megatron
+    and 3 json file:mm_data.json,mm_model.json,mm_tool.json
+    """
+    # args from mindspeed_mm
+    mindspeed_mm_params = ['load_base_model', "mm_data", "mm_tool", "mm_model"]
+    for param in mindspeed_mm_params:
+        if hasattr(args, param) and getattr(args, param):
+            file_legality_checker(getattr(args, param), param)
+
+    # args from mindspeed
+    mindspeed_param = ['auto_tuning_work_dir', "profile_save_path", "tokenizer_name_or_path", "additional_config",
+                       "layerzero_config", "prof_file"]
+    for param in mindspeed_param:
+        if hasattr(args, param) and getattr(args, param):
+            file_legality_checker(getattr(args, param), param)
+
+    # args from megatron
+    megatron_param = ["tensorboard_dir", "save", "load", "pretrained_checkpoint", "data_cache_path", "merge_file",
+                      "s3_cache_path", "ict_load", "bert_load", "titles_data_path", "evidence_data_path",
+                      "block_data_path", "embedding_path", "yaml_cfg"]
+    for param in megatron_param:
+        if hasattr(args, param) and getattr(args, param):
+            file_legality_checker(getattr(args, param), param)
+
+    # These parameters may have the following format:weight path weight path
+    megatron_special_params = ["data_path", "train_data_path", "valid_data_path", "test_data_path"]
+    for param in megatron_special_params:
+        if hasattr(args, param) and getattr(args, param):
+            raw = getattr(args, param)
+            file_list = split_param(raw if isinstance(raw, str) else " ".join(raw))
+            for path in file_list:
+                file_legality_checker(path, param)
+
+    # args from MM_Model
+    MM_model_params = ["text_encoder.from_pretrained", "text_encoder.template_file_path", "text_encoder.ckpt_path",
+                       "image_encoder.vision_encoder.ckpt_path", "image_encoder.vision_projector.ckpt_path",
+                       "ae.from_pretrained", "ae.from_pretrained_3dvae_ckpt", "ae.i2v_processor.processor_path",
+                       "ae.i2v_processor.image_encoder", "dpo.histgram_path",
+                       "tokenizer.from_pretrained", "tokenizer.template_file_path", "predictor.from_pretrained",
+                       "discriminator.perceptual_from_pretrained", "save_path", "prompt", "video", "image",
+                       "image_path", "file_path", "image_processer_path", "from_pretrained",
+                       "conditional_pixel_values_path", "ckpt_path", "eval_config.dataset.basic_parm.data_path",
+                       "eval_config.dataset.basic_parm.data_folder", "eval_config.dataset.extra_param.prompt_file",
+                       "eval_config.dataset.extra_param.augmented_prompt_file", "eval_config.eval_result_path",
+                       "eval_config.image_path", "eval_config.long_eval_config",
+                       "result_output_path", "dataset_path", "evaluation_dataset"]
+    if args.mm.model:
+        for param in MM_model_params:
+            values = get_ConfigReader_value(args.mm.model, param)
+            for value in values:
+                if value:
+                    file_legality_checker(value, param)
+
+    # args from MM_Data
+    MM_data_params = ["dataset_param.basic_parameters.data_path", "dataset_param.basic_parameters.data_folder",
+                      "dataset_param.basic_parameters.dataset_dir", "dataset_param.basic_parameters.dataset",
+                      "dataset_param.basic_parameters.cache_dir", "dataset_param.tokenizer_config.from_pretrained",
+                      "dataset_param.tokenizer_config.template_file_path", "dataset_param.processor_path",
+                      "dataset_param.preprocess_parameters.model_name_or_path",
+                      "dataset_param.preprocess_parameters.processor_name_or_path",
+                      "dataset_param.video_folder", "dataloader_param.collate_param.processor_name_or_path"]
+    if args.mm.data:
+        for param in MM_data_params:
+            values = get_ConfigReader_value(args.mm.data, param)
+            for value in values:
+                if value:
+                    file_legality_checker(value, param)
+
+    # args from MM_Tool
+    MM_tool_params = ["profile.static_param.save_path", "profile.static_param.dynamic_param",
+                      "memory_profile.save_path", "sorafeature.save_path"]
+    if args.mm.tool:
+        for param in MM_tool_params:
+            values = get_ConfigReader_value(args.mm.tool, param)
+            for value in values:
+                if value:
+                    file_legality_checker(value, param)
+
+
+def file_legality_checker(file_path, param_name, base_dir=None):
+    """
+    Perform soft link and path traversal checks on file path
+    """
+    if not base_dir:
+        base_dir = os.getcwd()
+
+    # check file exist
+    try:
+        if not os.path.exists(file_path):
+            return False
+    except OSError:
+        return False
+
+    # check symbolic link
+    from mindspeed_mm.utils.security_utils.validate_path import normalize_path
+    try:
+        norm_path, is_link = normalize_path(file_path)
+        if is_link:
+            print(
+                "WARNING: [{}] {} is a symbolic link. Its normalized path is {}".format(param_name, file_path,
+                                                                                        norm_path))
+            return False
+    except OSError:
+        return False
+
+    # check path crossing
+    try:
+        # get absolute file path
+        norm_path = os.path.realpath(file_path)
+        # get absolute base dir path
+        base_directory = os.path.abspath(base_dir)
+        if norm_path != base_directory and not norm_path.startswith(base_directory + os.sep):
+            print("WARNING: [{}] {} attempts to traverse to a disallowed directory".format(param_name, file_path))
+            return False
+    except OSError:
+        return False
+
+    return True
+
+
+def split_param(param):
+    """
+    Segment some special parameters in megatron
+    """
+
+    def is_number(s):
+        if isinstance(s, str):
+            s = s.strip()
+        try:
+            float(s)
+            return True
+        except (ValueError, TypeError):
+            return False
+
+    param_list = param.split(" ")
+    if len(param_list) == 1:
+        return param_list
+    else:
+        if is_number(param_list[0]):
+            return [param_list[2 * i + 1] for i in range(len(param_list) // 2)]
+        else:
+            return param_list
+
+
+def get_ConfigReader_value(config, param):
+    objs = [config.to_dict()]
+    for key in param.split("."):
+        new_objs = []
+        for obj in objs:
+            if not obj:
+                continue
+            if key in obj:
+                if not obj[key]:
+                    continue
+                if isinstance(obj[key], list):
+                    new_objs.extend(obj[key])
+                else:
+                    new_objs.append(obj[key])
+        if new_objs:
+            objs = new_objs
+        else:
+            return []
+
+    return objs