From b853b422c49043ccaeb7ee0293891696fef638d9 Mon Sep 17 00:00:00 2001
From: addsubmuldiv
Date: Thu, 14 Aug 2025 17:06:09 +0800
Subject: [PATCH] [Security remediation] Add path validation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/openmind/archived/cli_legacy/model_cli.py |  28 +-
 .../archived/cli_legacy/pipeline_cli.py       |  13 +-
 src/openmind/archived/pipelines/common/hf.py  |  15 +-
 src/openmind/flow/arguments.py                |   8 +-
 src/openmind/flow/datasets/loader.py          |   8 +
 src/openmind/flow/model/loader.py             |   3 +-
 src/openmind/integrations/datasets.py         |   3 +
 src/openmind/utils/arguments_utils.py         | 280 ++++++++++++++++-
 src/openmind/utils/hub.py                     |  13 +-
 tests/unit/cli/test_push.py                   |  77 +++--
 tests/unit/utils/test_path_check.py           | 296 ++++++++++++++++++
 11 files changed, 685 insertions(+), 59 deletions(-)
 create mode 100644 tests/unit/utils/test_path_check.py
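Usage sketch (illustrative only, not part of the applied diff): the hunks below route user-supplied paths through new helpers in src/openmind/utils/arguments_utils.py. The two core validators return the resolved real path and raise ValueError / FileNotFoundError / PermissionError on failure. A minimal sketch, assuming a freshly created directory and YAML file:

    import os
    import tempfile

    from openmind.utils.arguments_utils import validate_directory, validate_file_path

    work_dir = tempfile.mkdtemp()
    cfg = os.path.join(work_dir, "config.yaml")
    with open(cfg, "w") as f:
        f.write("stage: sft\n")

    # Directories are normalized and resolved via os.path.realpath; symlinked
    # directories pass with a warning under the default allow_symlinks=True.
    real_dir = validate_directory(work_dir)

    # Files must be existing regular files; an extension whitelist is optional.
    real_cfg = validate_file_path(cfg, allowed_extensions=[".yaml", ".yml"])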
diff --git a/src/openmind/archived/cli_legacy/model_cli.py b/src/openmind/archived/cli_legacy/model_cli.py
index 85b3857..09765a9 100644
--- a/src/openmind/archived/cli_legacy/model_cli.py
+++ b/src/openmind/archived/cli_legacy/model_cli.py
@@ -34,7 +34,13 @@ from openmind.utils.constants import (
     SNAPSHOTS,
     MODEL_CONFIG,
 )
-from openmind.utils.arguments_utils import str2bool, str2bool_or_auto
+from openmind.utils.arguments_utils import (
+    str2bool,
+    str2bool_or_auto,
+    validate_directory,
+    validate_file_path,
+    validate_cache_dir,
+)
 
 logger = logging.get_logger(__name__)
 
@@ -212,6 +218,15 @@ def try_to_trans_to_list(patterns):
 
 
 def validata_args(args):
+    if hasattr(args, "local_dir") and args.local_dir:
+        validate_directory(args.local_dir)
+
+    if hasattr(args, "cache_dir") and args.cache_dir:
+        validate_cache_dir(args.cache_dir)
+
+    if hasattr(args, "folder_path") and args.folder_path:
+        validate_directory(args.folder_path)
+
     if args.allow_patterns:
         args.allow_patterns = try_to_trans_to_list(args.allow_patterns)
 
@@ -326,6 +341,7 @@ def _check_file_exists(directory, filename):
             if os.path.islink(target_file):
                 target_path = os.readlink(target_file)
                 absolute_target_path = os.path.realpath(os.path.join(os.path.dirname(target_file), target_path))
+                validate_file_path(absolute_target_path)
                 if os.path.exists(absolute_target_path):
                     return True
         elif filename in files:
@@ -352,6 +368,7 @@ def _check_git_om_model(model_path):
     git_head_path = os.path.join(model_path, GIT_LOGS_HEAD)
 
     if os.path.exists(git_head_path):
+        validate_file_path(git_head_path)
         with open(git_head_path, "r") as f:
             git_log_info = f.read().split()[-1]
             if OPENMIND_PREFIX in git_log_info:
@@ -379,6 +396,7 @@ def _check_cache_om_model(model_path):
     git_head_path = os.path.join(model_path, GIT_LOGS_HEAD)
 
     if os.path.isdir(model_path):
+        validate_directory(model_path)
         for file in os.listdir(model_path):
             if file == SNAPSHOTS:
                 model_cache_path = os.path.join(model_path, file)
@@ -433,10 +451,12 @@ def _get_model_info(args: argparse.Namespace) -> set:
 
     if args.local_dir:
         local_path = Path(args.local_dir).absolute()
+        validate_directory(str(local_path))
         _add_model_info(local_path, model_info)
 
     if args.cache_dir:
         cache_path = Path(args.cache_dir).absolute()
+        validate_directory(str(cache_path))
         _add_model_info(cache_path, model_info)
 
     if not args.local_dir and not args.cache_dir:
@@ -468,8 +488,10 @@ def run_rm():
 
     for model in model_info:
         if args.repo_id == model[0]:
-            print(f"Deleted file path: {model[1]}")
-            shutil.rmtree(model[1])
+            model_path = model[1]
+            validate_directory(model_path)
+            print(f"Deleted file path: {model_path}")
+            shutil.rmtree(model_path)
             delete_num += 1
     if delete_num > 1:
         print("Files deleted successfully.")
diff --git a/src/openmind/archived/cli_legacy/pipeline_cli.py b/src/openmind/archived/cli_legacy/pipeline_cli.py
index 995f7cb..1c0ecd4 100644
--- a/src/openmind/archived/cli_legacy/pipeline_cli.py
+++ b/src/openmind/archived/cli_legacy/pipeline_cli.py
@@ -16,10 +16,10 @@
 import ast
 import sys
 import tempfile
-import yaml
-
+import os
+from openmind.utils.arguments_utils import validate_file_path, validate_image_path, validate_input_or_path
 from openmind.utils import is_vision_available, is_ms_available
-from openmind.utils.arguments_utils import _trans_args_list_to_dict
+from openmind.utils.arguments_utils import _trans_args_list_to_dict, safe_load_yaml, validate_directory
 from openmind import pipeline
 from openmind.archived.pipelines.pipeline_utils import SUPPORTED_TASK_MAPPING, get_task_from_readme
 from openmind.archived.pipelines.builder import _parse_native_json
@@ -58,8 +58,7 @@ def parse_args():
 
     if known_args.yaml_path is not None:
         yaml_path = known_args.yaml_path
-        with open(yaml_path, "r") as f:
-            yaml_all_args = yaml.safe_load(f)
+        yaml_all_args = safe_load_yaml(yaml_path)
 
         defined_params = {action.dest: action for action in parser._actions}
         yaml_args = {}
@@ -134,6 +133,9 @@ def try_to_trans_to_dict(input_or_path):
 
 
 def _init_pipeline(**kwargs):
+    model_path = kwargs.get("model", None)
+    if model_path and os.path.exists(model_path):
+        validate_directory(model_path)
     return pipeline(**kwargs)
 
 
@@ -157,6 +159,7 @@ def _extract_params(args):
 
 def _run_cmd_without_docker(params) -> None:
     input_or_path = params.pop("input")
+    validate_input_or_path(input_or_path)
     is_pt_framework = params.get("framework", None) is None or params.get("framework") == "pt"
     if is_pt_framework and params.get("device", None) is None and params.get("device_map", None) is None:
         params["device"] = "npu:0"
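model_cli.py above resolves symlinks with os.readlink/os.path.realpath before validating. A small sketch of the symlink policy implemented by validate_directory later in this patch (assumes a POSIX system, since it uses os.symlink):

    import os
    import tempfile

    from openmind.utils.arguments_utils import validate_directory

    target = tempfile.mkdtemp()
    link = target + "_link"
    os.symlink(target, link)

    validate_directory(link)  # accepted: logs a warning, returns realpath(target)

    try:
        validate_directory(link, allow_symlinks=False)  # strict mode rejects links
    except ValueError as e:
        print(f"rejected: {e}")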
""" + validate_image_path_list(inputs) return self.pipeline( inputs, **kwargs, @@ -543,6 +545,7 @@ class ImageToImagePipeline(HFPipeline): single image, the return will be also a single image, if the input is a list of several images, it will return a list of transformed images. """ + validate_image_path_list(images) return self.pipeline( images=images, **kwargs, @@ -604,7 +607,7 @@ class MaskGenerationPipeline(HFPipeline): the "object" described by the label and the mask. """ - + validate_image_path_list(image) return self.pipeline( image=image, *args, @@ -654,7 +657,7 @@ class ZeroShotImageClassificationPipeline(HFPipeline): - **score** (`float`) -- The score attributed by the model to that label. It is a value between 0 and 1, computed as the `softmax` of `logits_per_image`. """ - + validate_image_path_list(image) return self.pipeline( image, **kwargs, @@ -745,7 +748,7 @@ class ImageClassificationPipeline(HFPipeline): - **label** (`str`) -- The label identified by the model. - **score** (`int`) -- The score attributed by the model for that label. """ - + validate_image_path_list(inputs) return self.pipeline( inputs, **kwargs, @@ -794,7 +797,7 @@ class ImageToTextPipeline(HFPipeline): - **generated_text** (`str`) -- The generated text. """ - + validate_image_path_list(inputs) return self.pipeline( inputs, **kwargs, diff --git a/src/openmind/flow/arguments.py b/src/openmind/flow/arguments.py index 9013e1e..5e02839 100644 --- a/src/openmind/flow/arguments.py +++ b/src/openmind/flow/arguments.py @@ -18,11 +18,10 @@ import importlib import importlib.metadata import re -import yaml from openmind.utils.constants import Stages, FinetuneType, Frameworks from openmind.utils.import_utils import is_swanlab_available -from openmind.utils.arguments_utils import str2bool +from openmind.utils.arguments_utils import str2bool, validate_directory from openmind.utils import logging, is_transformers_available, is_torch_available, is_trl_available from openmind.flow.legacy_arguments import _add_legacy_args, _migrate_legacy_args @@ -105,9 +104,9 @@ def parse_args(yaml_path=None, ignore_unknown_args=False, custom_args=None): def parse_yaml_file(parser, yaml_path): """Parse and check yaml arguments""" + from openmind.utils.arguments_utils import safe_load_yaml - with open(yaml_path, "r") as f: - yaml_args = yaml.safe_load(f) + yaml_args = safe_load_yaml(yaml_path) defined_params = {action.dest for action in parser._actions} known_args = {k: v for k, v in yaml_args.items() if k in defined_params} @@ -151,6 +150,7 @@ def validate_args(args): # Detecting last checkpoint if os.path.isdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: + validate_directory(args.output_dir) last_checkpoint = get_last_checkpoint(args.output_dir) if last_checkpoint is None and len(os.listdir(args.output_dir)) > 0: raise ValueError( diff --git a/src/openmind/flow/datasets/loader.py b/src/openmind/flow/datasets/loader.py index fb46e69..e83a964 100644 --- a/src/openmind/flow/datasets/loader.py +++ b/src/openmind/flow/datasets/loader.py @@ -90,6 +90,14 @@ def _get_merged_datasets(dataset_names: Optional[str]): custom_dataset_path.strip(" ") for custom_dataset_path in args.custom_dataset_info.split(",") ] for custom_dataset_path in custom_dataset_path_list: + # check custom_dataset_path + from openmind.utils.arguments_utils import validate_file_path + + try: + # check custom_dataset_path + validate_file_path(custom_dataset_path, allowed_extensions=[".json"]) + except Exception as e: + raise ValueError(f"Invalid custom 
diff --git a/src/openmind/flow/arguments.py b/src/openmind/flow/arguments.py
index 9013e1e..5e02839 100644
--- a/src/openmind/flow/arguments.py
+++ b/src/openmind/flow/arguments.py
@@ -18,11 +18,10 @@ import importlib
 import importlib.metadata
 import re
 
-import yaml
 
 from openmind.utils.constants import Stages, FinetuneType, Frameworks
 from openmind.utils.import_utils import is_swanlab_available
-from openmind.utils.arguments_utils import str2bool
+from openmind.utils.arguments_utils import str2bool, validate_directory
 from openmind.utils import logging, is_transformers_available, is_torch_available, is_trl_available
 from openmind.flow.legacy_arguments import _add_legacy_args, _migrate_legacy_args
@@ -105,9 +104,9 @@ def parse_args(yaml_path=None, ignore_unknown_args=False, custom_args=None):
 
 def parse_yaml_file(parser, yaml_path):
     """Parse and check yaml arguments"""
+    from openmind.utils.arguments_utils import safe_load_yaml
 
-    with open(yaml_path, "r") as f:
-        yaml_args = yaml.safe_load(f)
+    yaml_args = safe_load_yaml(yaml_path)
 
     defined_params = {action.dest for action in parser._actions}
     known_args = {k: v for k, v in yaml_args.items() if k in defined_params}
@@ -151,6 +150,7 @@ def validate_args(args):
 
     # Detecting last checkpoint
     if os.path.isdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
+        validate_directory(args.output_dir)
         last_checkpoint = get_last_checkpoint(args.output_dir)
         if last_checkpoint is None and len(os.listdir(args.output_dir)) > 0:
             raise ValueError(
diff --git a/src/openmind/flow/datasets/loader.py b/src/openmind/flow/datasets/loader.py
index fb46e69..e83a964 100644
--- a/src/openmind/flow/datasets/loader.py
+++ b/src/openmind/flow/datasets/loader.py
@@ -90,6 +90,14 @@ def _get_merged_datasets(dataset_names: Optional[str]):
             custom_dataset_path.strip(" ") for custom_dataset_path in args.custom_dataset_info.split(",")
         ]
         for custom_dataset_path in custom_dataset_path_list:
+            # check custom_dataset_path
+            from openmind.utils.arguments_utils import validate_file_path
+
+            try:
+                # must be an existing, readable .json file
+                validate_file_path(custom_dataset_path, allowed_extensions=[".json"])
+            except Exception as e:
+                raise ValueError(f"Invalid custom dataset info path: {custom_dataset_path}. Error: {str(e)}") from e
             with open(custom_dataset_path, "r") as f:
                 custom_dataset_info = json.load(f)
             dataset_info.update(custom_dataset_info)
diff --git a/src/openmind/flow/model/loader.py b/src/openmind/flow/model/loader.py
index a6fe3b5..843a37e 100644
--- a/src/openmind/flow/model/loader.py
+++ b/src/openmind/flow/model/loader.py
@@ -44,7 +44,7 @@ from openmind.flow.model.adapter import apply_adapter
 from openmind.flow.model.sequence_parallel.seq_utils import apply_sequence_parallel
 from openmind.integrations.transformers.bitsandbytes import patch_bnb
 from openmind.utils.loader_utils import get_platform_loader
-from openmind.utils.arguments_utils import print_formatted_table
+from openmind.utils.arguments_utils import print_formatted_table, validate_directory
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -82,6 +82,7 @@ def try_download_from_hub() -> str:
         ) from e
 
     if os.path.exists(args.model_name_or_path):
+        validate_directory(args.model_name_or_path)
         return args.model_name_or_path
 
     return snapshot_download_func(
diff --git a/src/openmind/integrations/datasets.py b/src/openmind/integrations/datasets.py
index caebe33..a2d5aea 100644
--- a/src/openmind/integrations/datasets.py
+++ b/src/openmind/integrations/datasets.py
@@ -762,6 +762,9 @@ def load_dataset(
         raise ValueError("The path should be in the form of `namespace/datasetname` or local path")
     elif is_local_path:
         logger.info("Using local dataset")
+        from openmind.utils.arguments_utils import validate_directory
+
+        validate_directory(path)
     else:
         try:
             openmind_hub.repo_info(repo_id=path, repo_type="dataset", token=token)
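arguments.py and pipeline_cli.py now read YAML only through safe_load_yaml, so path checks run before parsing. A minimal sketch, assuming a readable train_sft.yaml exists in the working directory (the file names are illustrative; the second call assumes a typical Linux host):

    from openmind.utils.arguments_utils import safe_load_yaml

    args = safe_load_yaml("train_sft.yaml")  # must be an existing .yaml/.yml file

    try:
        safe_load_yaml("/etc/passwd")        # wrong extension -> ValueError
    except ValueError as e:
        print(e)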
diff --git a/src/openmind/utils/arguments_utils.py b/src/openmind/utils/arguments_utils.py
index adb7317..ffa4b9a 100644
--- a/src/openmind/utils/arguments_utils.py
+++ b/src/openmind/utils/arguments_utils.py
@@ -18,8 +18,17 @@ import argparse
 import os
 
 import yaml
-
+from openmind.utils import logging
 from tabulate import tabulate
+from typing import List, Union
+import re
+from openmind.utils import is_vision_available
+
+if is_vision_available():
+    from PIL.Image import Image
+
+
+logger = logging.get_logger(__name__)
 
 
 def _trans_args_list_to_dict(args_list: list) -> dict:
@@ -70,26 +79,277 @@ def str2bool_or_auto(value):
     raise argparse.ArgumentTypeError("Value should be one of: 'true', 'false', or 'auto'.")
 
 
+def validate_directory(path, allow_symlinks=True):
+    """
+    Validate directory path with realpath resolution
+    """
+    if not isinstance(path, str) or not path.strip():
+        raise ValueError("Directory path must be a non-empty string")
+
+    # Normalize and resolve real path
+    normalized_path = os.path.normpath(path.strip())
+    real_path = os.path.realpath(normalized_path)
+
+    # Symlink policy
+    if not allow_symlinks:
+        if os.path.islink(normalized_path) or real_path != normalized_path:
+            raise ValueError(f"Symbolic link not allowed: {path} -> {real_path}")
+    elif os.path.islink(normalized_path):
+        logger.warning(f"Directory path is a symbolic link: {normalized_path} -> {real_path}")
+
+    # Check if the resolved path is a directory
+    if not os.path.isdir(real_path):
+        raise ValueError(f"Path is not a directory or does not exist: {real_path}")
+
+    # Character validation on the resolved path
+    if os.name == "nt":  # Windows
+        invalid_chars = r'[<>"|?*\x00-\x1f]'
+        if re.search(invalid_chars, real_path):
+            raise ValueError(f"Directory path contains invalid characters: {real_path}")
+
+    # Component pattern check
+    if os.name == "nt":  # Windows
+        # Windows paths can contain backslashes and drive letters
+        pattern = r"^([A-Za-z]:)?([\\\/]|\.|[0-9a-zA-Z_\-\s\.\(\)\[\]\{\}~\u4e00-\u9fa5])*$"
+    else:  # Unix-like systems
+        pattern = r"(\.|/|_|-|\s|[~0-9a-zA-Z]|[\u4e00-\u9fa5])+"
+    if not re.fullmatch(pattern, real_path):
+        raise RuntimeError(f"Invalid directory path: {real_path}")
+
+    return real_path
+
+
+def validate_cache_dir(cache_dir):
+    """
+    Validate a cache directory path; the directory is allowed not to exist yet,
+    so only the path format is checked.
+    """
+    if not isinstance(cache_dir, str):
+        raise ValueError(f"Cache directory must be a string, got {type(cache_dir)}")
+
+    # For cache directories, we only need to validate the legality of the path format,
+    # not requiring the directory to exist
+    normalized_path = os.path.normpath(cache_dir.strip())
+
+    # Basic path format validation
+    if os.name == "nt":  # Windows
+        pattern = r"^([A-Za-z]:)?([\\\/]|\.|[0-9a-zA-Z_\-\s\.\(\)\[\]\{\}~\u4e00-\u9fa5])*$"
+    else:  # Unix-like systems
+        pattern = r"^(/|\./?|[~0-9a-zA-Z_\-\s\./\u4e00-\u9fa5])+$"
+
+    if not re.fullmatch(pattern, normalized_path):
+        raise ValueError(f"Invalid cache directory path format: {cache_dir}")
+
+    return normalized_path
+
+
+def _validate_file_path_common(
+    file_path: str,
+    allowed_extensions: list = None,
+    allow_symlinks=True,
+    require_read=True,
+    require_write=False,
+):
+    """
+    Common file path validation helper
+
+    Args:
+        file_path: File path to validate
+        allowed_extensions: List of allowed file extensions
+        allow_symlinks: Whether to allow symbolic links
+        require_read: Whether to require read access to the file
+        require_write: Whether to require write access to the file
+
+    Returns:
+        str: Normalized file path
+    """
+    normalized_path = os.path.normpath(file_path)
+    real_path = os.path.realpath(normalized_path)
+
+    # Symlink policy: if disallowed, reject when the resolved path differs or the
+    # original path (or any component) is a symlink
+    if not allow_symlinks:
+        if os.path.islink(normalized_path) or real_path != normalized_path:
+            raise ValueError(f"Symbolic link not allowed: {file_path} -> {real_path}")
+
+    # Must be an existing regular file
+    if not os.path.isfile(real_path):
+        # isfile implicitly checks existence
+        raise FileNotFoundError(f"File not found or not a regular file: {file_path}")
+
+    # Extension whitelist
+    if allowed_extensions:
+        normalized_extensions = [ext if ext.startswith(".") else f".{ext}" for ext in allowed_extensions]
+        if not any(real_path.lower().endswith(ext.lower()) for ext in normalized_extensions):
+            raise ValueError(f"File must have one of extensions: {normalized_extensions}")
+
+    # Permissions
+    if require_read and not os.access(real_path, os.R_OK):
+        raise PermissionError(f"No read permission: {real_path}")
+    if require_write and not os.access(real_path, os.W_OK):
+        raise PermissionError(f"No write permission: {real_path}")
+
+    if os.name == "nt":  # Windows
+        # Windows paths can contain backslashes and drive letters
+        pattern = r"^([A-Za-z]:)?([\\\/]|\.|[0-9a-zA-Z_\-\s\.\(\)\[\]\{\}~\u4e00-\u9fa5])*$"
+    else:  # Unix-like systems
+        pattern = r"(\.|/|_|-|\s|[~0-9a-zA-Z]|[\u4e00-\u9fa5])+"
+    if not re.fullmatch(pattern, real_path):
+        raise RuntimeError(f"Invalid input path: {real_path}")
+    return real_path
+
+
 def safe_load_yaml(path):
     if path is None:
-        raise ValueError("param `path` is required for `safe_load_yaml` func.")
+        raise ValueError("param `path` is required for `safe_load_yaml`.")
     if not isinstance(path, str):
         raise TypeError(f"param `path` should be string format for `safe_load_yaml func`, but got {type(path)} type.")
 
-    if not os.path.exists(path):
-        raise FileNotFoundError(f"yaml file path {path} does not exist.")
-
-    path = os.path.realpath(path)
-
-    if not path.endswith(".yaml") and not path.endswith(".yml"):
-        raise ValueError(f"path {path} is not a yaml/yml file path.")
+    normalized_path = _validate_file_path_common(path, allowed_extensions=[".yaml", ".yml"], require_read=True)
 
-    with open(path, "r") as file:
+    with open(normalized_path, "r") as file:
         content = yaml.safe_load(file)
 
     return content
 
 
+def validate_file_path(
+    file_path,
+    allowed_extensions: list = None,
+    allow_symlinks: bool = True,
+    require_read: bool = True,
+    require_write: bool = False,
+):
+    """
+    Validate the security of a file path
+
+    Args:
+        file_path: File path to validate
+        allowed_extensions: List of allowed file extensions
+        allow_symlinks: Whether to allow symbolic links
+        require_read: Whether to require read access to the file
+        require_write: Whether to require write access to the file
+    """
+    # Directly call the common function
+    return _validate_file_path_common(file_path, allowed_extensions, allow_symlinks, require_read, require_write)
+
+
+def validate_image_path(image_path):
+    """Validate the security of image paths"""
+    if not isinstance(image_path, str):
+        # If it's not a string (e.g. already a PIL image object), return it unchanged
+        return image_path
+
+    if (
+        os.path.isabs(image_path)
+        or image_path.startswith("./")
+        or image_path.startswith("../")
+        or os.path.exists(image_path)
+    ):
+        return _validate_file_path_common(image_path, require_read=True)
+
+    # Otherwise treat it as a URL-like string and run the character-level check
+    return validate_url(image_path)
+
+
+def validate_image_path_list(image: Union[str, List[str], "Image", List["Image"]]):
+    # Perform security validation on image input
+    if isinstance(image, str):
+        validate_image_path(image)
+    elif isinstance(image, list):
+        # If it's a list, validate each element
+        validated_images = []
+        for img in image:
+            validate_image_path(img)
+            validated_images.append(img)
+        image = validated_images
+    return image
+
+
+def validate_url(url: str) -> str:
+    """
+    Perform only character-level legality checks for a URL-like string.
+    This does NOT validate structure (netloc, path semantics, reachability).
+    Checks:
+      - non-empty str
+      - no control / whitespace chars
+      - scheme (if present) matches ^[A-Za-z][A-Za-z0-9+.-]*$
+      - percent-encoding tokens are well-formed (every '%' followed by two hex digits)
+      - all chars belong to a conservative allowed set (RFC 3986 unreserved + reserved)
+    Returns the original string if all checks pass; raises ValueError otherwise.
+    """
+    if not isinstance(url, str) or not url:
+        raise ValueError("Invalid URL: must be a non-empty string")
+
+    # Reject control chars and spaces
+    if re.search(r"[\x00-\x20\x7F]", url):
+        raise ValueError(f"Invalid URL: contains control or whitespace characters: {url!r}")
+
+    # Percent-encoding correctness
+    if re.search(r"%(?![0-9A-Fa-f]{2})", url):
+        raise ValueError(f"Invalid URL: malformed percent-encoding: {url!r}")
+
+    # Optional scheme check (only if a scheme appears)
+    if ":" in url:
+        scheme = url.split(":", 1)[0]
+        if scheme and not re.fullmatch(r"[A-Za-z][A-Za-z0-9+.\-]*", scheme):
+            raise ValueError(f"Invalid URL: bad scheme: {scheme!r}")
+
+    if url.endswith(":") or url.endswith("://"):
+        raise ValueError(f"Invalid URL: incomplete URL structure: {url!r}")
+
+    # Allowed characters (unreserved + reserved + percent-escapes + '#')
+    # unreserved: A-Z a-z 0-9 - . _ ~
+    # reserved:   : / ? # [ ] @ ! $ & ' ( ) * + , ; =
+    allowed_pattern = r"^[A-Za-z0-9\-._~:/?#\[\]@!$&\'()*+,;=%]*$"
+    if not re.fullmatch(allowed_pattern, url):
+        raise ValueError(f"Invalid URL: contains disallowed characters: {url!r}")
+
+    return url
+
+
+def validate_input_or_path(input_or_path):
+    """
+    Validate input parameter, which can be a URL, local file path, or plain string
+
+    Args:
+        input_or_path: Input parameter, could be a URL, file path, or plain string
+
+    Returns:
+        Validated input parameter
+
+    Raises:
+        ValueError: Raised when the input parameter is invalid
+    """
+    if not isinstance(input_or_path, (str, dict)):
+        raise ValueError(f"Input must be a string or dict, got {type(input_or_path)}")
+
+    # If it's a dict, return directly
+    if isinstance(input_or_path, dict):
+        return input_or_path
+
+    # First check if it's a valid URL-like string
+    try:
+        return validate_url(input_or_path)
+    except ValueError:
+        # Not a valid URL, continue to check if it's a local path
+        pass
+
+    # Check if it's a local file or directory path
+    if os.path.exists(input_or_path):
+        # If it's a directory, validate it as a directory
+        if os.path.isdir(input_or_path):
+            return validate_directory(input_or_path)
+        # If it's a file, validate it as a file
+        else:
+            return validate_file_path(input_or_path)
+
+    # Neither a URL nor an existing path: treat it as a plain string input
+    # (like a prompt) and return it unchanged
+    return input_or_path
+
+
 def print_formatted_table(data, header, missingval="N/A"):
     print(tabulate(data, header, missingval=missingval, tablefmt="fancy_grid"))
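validate_url above is deliberately a character-level screen rather than a full URL parser, which is why validate_input_or_path can fall back to treating non-URL, non-path strings (for example prompts) as plain input. A short sketch of accepted and rejected shapes:

    from openmind.utils.arguments_utils import validate_input_or_path, validate_url

    validate_url("https://example.com/a%20b?x=1")  # legal characters and %-escapes

    for bad in ("", "http://", "has space", "bad%zz"):
        try:
            validate_url(bad)
        except ValueError as e:
            print(f"rejected: {e}")

    validate_input_or_path({"image": "x.png"})         # dicts pass through unchanged
    validate_input_or_path("Translate this sentence")  # plain prompt, returned as-is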
diff --git a/src/openmind/utils/hub.py b/src/openmind/utils/hub.py
index fb776ed..b9d1093 100644
--- a/src/openmind/utils/hub.py
+++ b/src/openmind/utils/hub.py
@@ -23,6 +23,8 @@ import tempfile
 from typing import Dict, Optional, Union
 from uuid import uuid4
 import warnings
+from .arguments_utils import validate_directory, validate_file_path, validate_url
+
 
 from openmind_hub import (
     _CACHED_NO_EXIST,
@@ -211,6 +213,8 @@ class OpenMindHub(BaseHub):
             FutureWarning,
         )
 
+        validate_url(url)
+
         dir_name = os.path.join(OM_HOME, "tmp_files_from_url")
         os.makedirs(dir_name, exist_ok=True)
 
@@ -266,6 +270,8 @@ class OpenMindHub(BaseHub):
         """
         if not os.path.isfile(index_filename):
             raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
+        # index_filename is known to exist here; validate it before reading
+        validate_file_path(index_filename)
 
         with open(index_filename, "r") as f:
             index = json.loads(f.read())
@@ -276,6 +282,7 @@ class OpenMindHub(BaseHub):
 
         # First, let's deal with local folder.
         if os.path.isdir(pretrained_model_name_or_path):
+            validate_directory(pretrained_model_name_or_path)
             shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
             return shard_filenames, sharded_metadata
 
@@ -407,6 +414,7 @@ class OpenMindHub(BaseHub):
         path_or_repo_id = str(path_or_repo_id)
         full_filename = os.path.join(subfolder, filename)
         if os.path.isdir(path_or_repo_id):
+            validate_directory(path_or_repo_id)
             resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename)
             if not os.path.isfile(resolved_file):
                 if _raise_exceptions_for_missing_entries:
@@ -416,6 +424,7 @@ class OpenMindHub(BaseHub):
                     )
                 else:
                     return None
+            validate_file_path(resolved_file)
             return resolved_file
 
         if cache_dir is None:
@@ -774,7 +783,8 @@ class PushToHubMixin:
         ```
         """
         working_dir = repo_id.split("/")[-1]
-
+        if os.path.exists(working_dir):
+            validate_directory(working_dir)
         repo_id = self._create_repo(
             repo_id,
             private=private,
@@ -854,6 +864,7 @@ class PushToHubMixin:
             commit_message = "Upload processor"
         else:
             commit_message = f"Upload {self.__class__.__name__}"
+        validate_directory(working_dir)
         modified_files = [
             f
             for f in os.listdir(working_dir)
diff --git a/tests/unit/cli/test_push.py b/tests/unit/cli/test_push.py
index d1cc1c3..9c0a4e6 100644
--- a/tests/unit/cli/test_push.py
+++ b/tests/unit/cli/test_push.py
@@ -12,42 +12,61 @@
 # See the Mulan PSL v2 for more details.
 
 
+import os
+import tempfile
 import unittest
+import shutil
 from unittest.mock import patch
 
 from openmind.archived.cli_legacy.model_cli import run_push
 
 
 class TestPush(unittest.TestCase):
-    @patch("sys.argv", ["openmind-cli", "push", "--repo_id", "test-repo", "--folder_path", "./model"])
+    def setUp(self):
+        self.temp_folder_path = tempfile.mkdtemp()
+        self.temp_model_file = os.path.join(self.temp_folder_path, "model.txt")
+        with open(self.temp_model_file, "w") as f:
+            f.write("test model content")
+
+    def tearDown(self):
+        shutil.rmtree(self.temp_folder_path, ignore_errors=True)
+
+    @patch("openmind.archived.cli_legacy.model_cli.upload_folder")
     def test_run_push_basic(self, mock_upload):
-        run_push()
-        mock_upload.assert_called_once_with(
-            repo_id="test-repo",
-            folder_path="./model",
-            path_in_repo=None,
-            commit_message=None,
-            commit_description=None,
-            token=None,
-            revision="main",
-            allow_patterns=None,
-            ignore_patterns=None,
-            num_threads=5,
-        )
-
-    @patch("sys.argv", ["openmind-cli", "push", "test-repo", "--folder_path", "./model"])
+        test_args = ["openmind-cli", "push", "--repo_id", "test-repo", "--folder_path", self.temp_folder_path]
+        with patch("sys.argv", test_args):
+            run_push()
+
+        mock_upload.assert_called_once_with(
+            repo_id="test-repo",
+            folder_path=self.temp_folder_path,
+            path_in_repo=None,
+            commit_message=None,
+            commit_description=None,
+            token=None,
+            revision="main",
+            allow_patterns=None,
+            ignore_patterns=None,
+            num_threads=5,
+        )
+
+    @patch("openmind.archived.cli_legacy.model_cli.upload_folder")
     def test_run_push_legacy(self, mock_upload):
-        run_push()
-        mock_upload.assert_called_once_with(
-            repo_id="test-repo",
-            folder_path="./model",
-            path_in_repo=None,
-            commit_message=None,
-            commit_description=None,
-            token=None,
-            revision="main",
-            allow_patterns=None,
-            ignore_patterns=None,
-            num_threads=5,
-        )
+        test_args = ["openmind-cli", "push", "test-repo", "--folder_path", self.temp_folder_path]
+        with patch("sys.argv", test_args):
+            run_push()
+
+        mock_upload.assert_called_once_with(
+            repo_id="test-repo",
+            folder_path=self.temp_folder_path,
+            path_in_repo=None,
+            commit_message=None,
+            commit_description=None,
+            token=None,
+            revision="main",
+            allow_patterns=None,
+            ignore_patterns=None,
+            num_threads=5,
+        )
diff --git a/tests/unit/utils/test_path_check.py b/tests/unit/utils/test_path_check.py
new file mode 100644
index 0000000..55778d9
--- /dev/null
+++ b/tests/unit/utils/test_path_check.py
@@ -0,0 +1,296 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd.
+#
+# openMind is licensed under Mulan PSL v2.
+# You can use this software according to the terms and conditions of the Mulan PSL v2.
+# You may obtain a copy of Mulan PSL v2 at:
+#
+#     http://license.coscl.org.cn/MulanPSL2
+#
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
+# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
+# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
+# See the Mulan PSL v2 for more details.
+
+import os
+import tempfile
+import unittest
+import yaml
+from PIL import Image
+import io
+
+from openmind.utils.arguments_utils import (
+    safe_load_yaml,
+    validate_file_path,
+    validate_image_path,
+    validate_image_path_list,
+    validate_url,
+    validate_input_or_path,
+)
+
+
+class ArgumentsUtilsTestCase(unittest.TestCase):
+
+    def create_temp_file(self, suffix="", content=None):
+        """Helper method to create a proper temporary file"""
+        temp_dir = tempfile.mkdtemp()
+        tmp_file_path = os.path.join(temp_dir, f"tempfile{suffix}")
+
+        if content is not None:
+            if isinstance(content, str):
+                with open(tmp_file_path, "w") as f:
+                    f.write(content)
+            else:
+                with open(tmp_file_path, "wb") as f:
+                    f.write(content)
+        else:
+            with open(tmp_file_path, "w") as f:
+                f.write("")
+
+        return tmp_file_path, temp_dir
+
+    def create_temp_dir(self):
+        """Helper method to create a proper temporary directory"""
+        temp_dir = tempfile.mkdtemp()
+        return os.path.realpath(temp_dir)
+
+    def cleanup_temp_files(self, *paths):
+        """Helper method to clean up temporary files and directories"""
+        for path in paths:
+            if os.path.isfile(path):
+                os.remove(path)
+                # Remove parent directory
+                parent_dir = os.path.dirname(path)
+                if os.path.isdir(parent_dir):
+                    try:
+                        os.rmdir(parent_dir)
+                    except OSError:
+                        pass  # Directory not empty or other issue
+            elif os.path.isdir(path):
+                try:
+                    os.rmdir(path)
+                except OSError:
+                    pass  # Directory not empty or other issue
+
+    def test_validate_url_valid_urls(self):
+        """Test that valid URLs pass validation"""
+        valid_urls = [
+            "http://example.com",
+            "https://example.com",
+            "https://sub.example.com/path?query=value#fragment",
+            "ftp://files.example.com",
+            "mailto:test@example.com",
+            "file:///path/to/file.txt",
+        ]
+
+        for url in valid_urls:
+            # Should not raise an exception
+            self.assertEqual(validate_url(url), url)
+
+    def test_validate_url_invalid_urls(self):
+        """Test that invalid URLs raise ValueError"""
+        invalid_urls = [
+            "",  # Empty string
+            " ",  # Whitespace
+            "http://",  # Missing host
+        ]
+
+        # These URLs should fail validation on all platforms
+        for url in invalid_urls:
+            with self.assertRaises(ValueError, msg=f"URL '{url}' should raise ValueError"):
+                validate_url(url)
+
+    def test_validate_file_path_valid_file(self):
+        """Test that valid file paths pass validation"""
+        tmp_file_path, temp_dir = self.create_temp_file()
+        try:
+            self.assertTrue(os.path.exists(tmp_file_path), f"File {tmp_file_path} should exist")
+            self.assertTrue(os.path.isfile(tmp_file_path), f"Path {tmp_file_path} should be a file")
+            # Should not raise an exception
+            result = validate_file_path(tmp_file_path, allow_symlinks=True)
+            # Normalize path for comparison
+            self.assertEqual(os.path.realpath(result), os.path.realpath(tmp_file_path))
+        finally:
+            self.cleanup_temp_files(tmp_file_path, temp_dir)
+
+    def test_validate_file_path_invalid_file(self):
+        """Test that invalid file paths raise appropriate exceptions"""
+        with self.assertRaises(FileNotFoundError):
+            validate_file_path("/non/existent/file.txt")
+
+    def test_validate_file_path_with_extensions(self):
+        """Test file path validation with allowed extensions"""
+        tmp_file_path, temp_dir = self.create_temp_file(".yaml")
+        try:
+            self.assertTrue(os.path.exists(tmp_file_path), f"File {tmp_file_path} should exist")
+            self.assertTrue(os.path.isfile(tmp_file_path), f"Path {tmp_file_path} should be a file")
+            # Should pass with correct extension
+            result = validate_file_path(tmp_file_path, allowed_extensions=[".yaml", ".yml"], allow_symlinks=True)
+            # Normalize path for comparison
+            self.assertEqual(os.path.realpath(result), os.path.realpath(tmp_file_path))
+
+            # A wrong extension should raise ValueError
+            with self.assertRaises(ValueError):
+                validate_file_path(
+                    tmp_file_path,
+                    allowed_extensions=[".txt"],
+                    allow_symlinks=True,
+                    require_read=False,
+                    require_write=False,
+                )
+        finally:
+            self.cleanup_temp_files(tmp_file_path, temp_dir)
+
+    def test_safe_load_yaml_valid_file(self):
+        """Test loading valid YAML file"""
+        yaml_content = {"key": "value", "number": 42}
+        tmp_file_path, temp_dir = self.create_temp_file(".yaml", yaml.dump(yaml_content))
+
+        try:
+            result = safe_load_yaml(tmp_file_path)
+            self.assertEqual(result, yaml_content)
+        finally:
+            self.cleanup_temp_files(tmp_file_path, temp_dir)
+
+    def test_safe_load_yaml_invalid_path(self):
+        """Test that safe_load_yaml raises error for invalid path"""
+        with self.assertRaises(FileNotFoundError):
+            safe_load_yaml("/non/existent/file.yaml")
+
+    def test_safe_load_yaml_invalid_type(self):
+        """Test that safe_load_yaml raises error for non-string path"""
+        with self.assertRaises(TypeError):
+            safe_load_yaml(123)
+
+    def test_safe_load_yaml_none_path(self):
+        """Test that safe_load_yaml raises error for None path"""
+        with self.assertRaises(ValueError):
+            safe_load_yaml(None)
string""" + img = Image.new("RGB", (100, 100), color="red") + img_bytes = io.BytesIO() + img.save(img_bytes, format="JPEG") + + tmp_file_path, temp_dir = self.create_temp_file(".jpg", img_bytes.getvalue()) + try: + result = validate_image_path_list(tmp_file_path) + self.assertEqual(result, tmp_file_path) + finally: + self.cleanup_temp_files(tmp_file_path, temp_dir) + + @unittest.skipUnless(hasattr(Image, "new"), "PIL not available") + def test_validate_image_path_list_image_object(self): + """Test that PIL Image objects pass through in list""" + img = Image.new("RGB", (100, 100), color="red") + result = validate_image_path_list(img) + self.assertIs(result, img) + + @unittest.skipUnless(hasattr(Image, "new"), "PIL not available") + def test_validate_image_path_list_multiple_images(self): + """Test validating list of image paths""" + images = [] + temp_paths = [] + temp_dirs = [] + + try: + for i in range(2): + img = Image.new("RGB", (100, 100), color="red") + img_bytes = io.BytesIO() + img.save(img_bytes, format="JPEG") + + tmp_file_path, temp_dir = self.create_temp_file(f"{i}.jpg", img_bytes.getvalue()) + images.append(tmp_file_path) + temp_paths.append(tmp_file_path) + temp_dirs.append(temp_dir) + + result = validate_image_path_list(images) + # Compare normalized paths + normalized_result = [os.path.normpath(p) for p in result] + normalized_expected = [os.path.normpath(p) for p in images] + self.assertEqual(normalized_result, normalized_expected) + finally: + for path, dir_path in zip(temp_paths, temp_dirs): + self.cleanup_temp_files(path, dir_path) + + def test_validate_input_or_path_dict(self): + """Test that dict input passes through unchanged""" + input_dict = {"key": "value"} + result = validate_input_or_path(input_dict) + self.assertIs(result, input_dict) + + def test_validate_input_or_path_invalid_type(self): + """Test that invalid input types raise ValueError""" + invalid_inputs = [123, ["list"], ("tuple",), None] + + for invalid_input in invalid_inputs: + with self.assertRaises(ValueError): + validate_input_or_path(invalid_input) + + def test_validate_input_or_path_valid_url(self): + """Test that valid URLs are validated and returned""" + url = "https://example.com" + result = validate_input_or_path(url) + self.assertEqual(result, url) + + def test_validate_input_or_path_valid_file(self): + """Test that existing file paths are validated""" + tmp_file_path, temp_dir = self.create_temp_file(content="test") + try: + result = validate_input_or_path(tmp_file_path) + # Normalize path for comparison + self.assertEqual(os.path.realpath(result), os.path.realpath(tmp_file_path)) + finally: + self.cleanup_temp_files(tmp_file_path, temp_dir) + + def test_validate_input_or_path_valid_directory(self): + """Test that existing directory paths are validated""" + temp_dir = self.create_temp_dir() + try: + result = validate_input_or_path(temp_dir) + # Normalize path for comparison + self.assertEqual(os.path.normpath(result), os.path.normpath(temp_dir)) + finally: + self.cleanup_temp_files(temp_dir) + + def test_validate_input_or_path_plain_string(self): + """Test that plain strings (non-URL, non-file) pass through""" + plain_string = "This is a plain text input" + result = validate_input_or_path(plain_string) + self.assertEqual(result, plain_string) + + def test_validate_input_or_path_nonexistent_path(self): + """Test that non-existent paths are treated as plain strings""" + nonexistent_path = "/non/existent/path.txt" + result = validate_input_or_path(nonexistent_path) + self.assertEqual(result, 
+        self.assertEqual(result, nonexistent_path)
--
Gitee