diff --git a/src/openmind/cli/cli.py b/src/openmind/cli/cli.py
index ba4798d52be22e59d31a53796bf5d4e02abe8030..31d5d349dee3a9daf81af0563ec87d44b69244c7 100644
--- a/src/openmind/cli/cli.py
+++ b/src/openmind/cli/cli.py
@@ -18,36 +18,52 @@ import os
 import subprocess
 import sys
 import random
+import importlib.util
 
 from openmind.utils.constants import Command
-from openmind.cli.chat import run_chat
-from openmind.cli.env import run_env
-from openmind.archived.cli_legacy.model_cli import run_pull, run_push, run_rm, run_list
-from openmind.archived.cli_legacy.pipeline_cli import run_pipeline
-from openmind.utils import is_torch_available
-
-# Compatible with MindSpore
-if is_torch_available():
-    import torch
-    from accelerate import PartialState
-    from openmind.cli import train
-    from openmind.cli.export import run_export
-    from openmind.cli.deploy import run_deploy
-    from openmind.cli.eval import run_eval
+from openmind.utils.import_utils import (
+    is_torch_available,
+)
+from openmind.utils.dependency_utils import check_dependencies
+
+COMMAND_DEPENDENCIES = {
+    Command.TRAIN: ["torch", "transformers"],
+    Command.EVAL: ["torch", "transformers"],
+    Command.EXPORT: ["torch", "transformers"],
+    Command.DEPLOY: ["torch"],
+}
 
 
 def get_device_count():
-    state = PartialState()
-    device_module = getattr(torch, state.device.type.lower(), None)
-    if device_module and hasattr(device_module, "device_count"):
-        return device_module.device_count()
+    if is_torch_available():
+        import torch
+        from accelerate import PartialState
+
+        state = PartialState()
+        device_module = getattr(torch, state.device.type.lower(), None)
+        if device_module and hasattr(device_module, "device_count"):
+            return device_module.device_count()
     return 0
 
 
 def main():
+    if len(sys.argv) < 2:
+        print("Usage: openmind-cli <command> [options]")
+        print("\nAvailable commands:")
+        for cmd in sorted(dir(Command)):
+            if not cmd.startswith("_"):
+                print(f"  {getattr(Command, cmd).lower()}")
+        return
+
     command_cli = sys.argv[1]
+
+    if command_cli in COMMAND_DEPENDENCIES:
+        check_dependencies(COMMAND_DEPENDENCIES[command_cli], f"openmind-cli {command_cli}")
+
     if command_cli == Command.TRAIN:
         if get_device_count() >= 1:
+            train_module_path = importlib.util.find_spec("openmind.cli.train").origin
+
             master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
             master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
             command = [
@@ -62,30 +78,50 @@ def main():
                 master_addr,
                 "--master_port",
                 master_port,
-                train.__file__,
+                train_module_path,
             ] + sys.argv[2::]
             subprocess.run(command)
         else:
             raise ValueError("There is no npu devices to launch finetune workflow")
     elif command_cli == Command.LIST:
+        from openmind.archived.cli_legacy.model_cli import run_list
+
         run_list()
     elif command_cli == Command.EVAL:
+        from openmind.cli.eval import run_eval
+
         run_eval()
     elif command_cli == Command.PULL:
+        from openmind.archived.cli_legacy.model_cli import run_pull
+
         run_pull()
     elif command_cli == Command.PUSH:
+        from openmind.archived.cli_legacy.model_cli import run_push
+
         run_push()
     elif command_cli == Command.RM:
+        from openmind.archived.cli_legacy.model_cli import run_rm
+
         run_rm()
     elif command_cli == Command.CHAT:
+        from openmind.cli.chat import run_chat
+
         run_chat()
     elif command_cli == Command.RUN:
+        from openmind.archived.cli_legacy.pipeline_cli import run_pipeline
+
         run_pipeline()
     elif command_cli == Command.ENV:
+        from openmind.cli.env import run_env
+
         run_env()
     elif command_cli == Command.DEPLOY:
+        from openmind.cli.deploy import run_deploy
+
         run_deploy()
     elif command_cli == Command.EXPORT:
+        from openmind.cli.export import run_export
+
         run_export()
     else:
         raise ValueError(f"Currently command {command_cli} is not supported")
diff --git a/src/openmind/cli/train.py b/src/openmind/cli/train.py
index 024efab40f99767d82f503d29447fb2f73cf89ab..fe471ac7e34e7c02df5aeb61a6c31a18d06c071f 100644
--- a/src/openmind/cli/train.py
+++ b/src/openmind/cli/train.py
@@ -17,6 +17,15 @@ from openmind.flow.arguments import get_args, initialize_openmind
 from openmind.flow.train import run_sft, run_pt, run_dpo, run_rm
 from openmind.flow.callbacks import get_swanlab_callbacks
 from openmind.utils.constants import Stages
+from openmind.utils.dependency_utils import check_dependencies
+
+
+STAGE_DEPENDENCIES = {
+    Stages.SFT: ["transformers"],
+    Stages.PT: ["transformers"],
+    Stages.DPO: ["transformers", "trl"],
+    Stages.RM: ["transformers", "trl"],
+}
 
 
 def run_train(**kwargs):
@@ -29,6 +38,9 @@ def run_train(**kwargs):
     initialize_openmind(yaml_file, **kwargs)
     args = get_args()
 
+    if args.stage in STAGE_DEPENDENCIES:
+        check_dependencies(STAGE_DEPENDENCIES[args.stage], f"Training stage {args.stage}")
+
     callbacks = get_swanlab_callbacks()
 
     if args.stage == Stages.SFT:
diff --git a/src/openmind/utils/__init__.py b/src/openmind/utils/__init__.py
index 9e58203ca00d6aff7bcbdff0b960f9897f079bc2..afb80f7e2bb0923fc553ce548ff6d068f5b77d7f 100644
--- a/src/openmind/utils/__init__.py
+++ b/src/openmind/utils/__init__.py
@@ -29,6 +29,7 @@ __all__ = [
     "is_pyav_available",
     "logging",
     "get_logger",
+    "check_dependencies",
 ]
 
 from .import_utils import (
@@ -57,3 +58,5 @@
 )
 
 from .logging import get_logger
+
+from .dependency_utils import check_dependencies
diff --git a/src/openmind/utils/dependency_utils.py b/src/openmind/utils/dependency_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..58bca57632e8e739eb18a2ff38190548d89b122e
--- /dev/null
+++ b/src/openmind/utils/dependency_utils.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd.
+#
+# openMind is licensed under Mulan PSL v2.
+# You can use this software according to the terms and conditions of the Mulan PSL v2.
+# You may obtain a copy of Mulan PSL v2 at:
+#
+#     http://license.coscl.org.cn/MulanPSL2
+#
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
+# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
+# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
+# See the Mulan PSL v2 for more details.
+
+
+from openmind.utils.import_utils import (
+    is_torch_available,
+    is_transformers_available,
+    is_trl_available,
+    PYTORCH_IMPORT_ERROR,
+    TRANSFORMERS_IMPORT_ERROR,
+    TRL_IMPORT_ERROR,
+)
+
+# Common dependencies for modules
+COMMON_DEPENDENCIES = {
+    "torch": (is_torch_available, PYTORCH_IMPORT_ERROR),
+    "transformers": (is_transformers_available, TRANSFORMERS_IMPORT_ERROR),
+    "trl": (is_trl_available, TRL_IMPORT_ERROR),
+}
+
+
+def check_dependencies(dependencies, context=""):
+    """
+    Common dependencies checker
+
+    Args:
+        dependencies: dependencies list or dict
+        context: Context for error message
+    """
+    if isinstance(dependencies, list):
+        deps_dict = {dep: COMMON_DEPENDENCIES[dep] for dep in dependencies if dep in COMMON_DEPENDENCIES}
+    else:
+        deps_dict = dependencies
+
+    for _, (check_func, error_msg) in deps_dict.items():
+        if not check_func():
+            raise ImportError(error_msg.format(context))
+    return True
diff --git a/src/openmind/utils/import_utils.py b/src/openmind/utils/import_utils.py
index 9f4474fd5b8cfc72ed99323eecd123a8d1310872..fb79ad88298ab7197031b474b49c07c6691432e1 100644
--- a/src/openmind/utils/import_utils.py
+++ b/src/openmind/utils/import_utils.py
@@ -214,6 +214,11 @@ FRAMEWORK_NOT_FOUND_ERROR = """
 framework you want to use and note that you may need to restart your runtime after installation.
 """
 
+TRL_IMPORT_ERROR = """
+{0} requires the TRL library but it was not found in your environment. You can install it with pip:
+`pip install trl`. Please note that you may need to restart your runtime after installation.
+"""
+
 BACKENDS_MAPPING = OrderedDict(
     [
         ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
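
Usage sketch (not part of the diff): a minimal illustration of how the new check_dependencies helper is expected to behave once this change lands, mirroring the COMMAND_DEPENDENCIES / STAGE_DEPENDENCIES call sites added above. The missing-trl scenario is assumed purely for illustration.

# Hypothetical usage sketch; assumes the package layout introduced by this
# diff, where check_dependencies is re-exported from openmind.utils.
from openmind.utils import check_dependencies

# A list of names is resolved through COMMON_DEPENDENCIES ("torch",
# "transformers", "trl"); when every requested backend is importable,
# the call simply returns True.
check_dependencies(["torch", "transformers"], context="openmind-cli train")

# When a backend is missing, the matching *_IMPORT_ERROR template is
# formatted with the context string and raised as ImportError, e.g. a
# missing trl install for the DPO stage would surface roughly as:
#   "Training stage dpo requires the TRL library but it was not found in
#    your environment. You can install it with pip: `pip install trl`. ..."
try:
    check_dependencies(["trl"], context="Training stage dpo")
except ImportError as err:
    print(err)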