diff --git a/mindformers/core/config_args.py b/mindformers/core/config_args.py
index 6c55eb3fb3130279b2f3868ef3389160eaf8ac20..82652d5570bdabafea8ba1e8664f5adb252e9ab7 100644
--- a/mindformers/core/config_args.py
+++ b/mindformers/core/config_args.py
@@ -467,7 +467,7 @@ class ContextConfig(BaseArgsConfig):
             save_graphs_path: str = ".",
             **kwargs,
     ):
-        super(ContextConfig, self).__init__(
+        super().__init__(
             mode=mode,
             device_id=device_id,
             device_target=device_target,
@@ -499,6 +499,7 @@ class MFContextConfig(BaseArgsConfig):
         'resume_training',
         'use_graceful_exit',
         'affinity_cpu_list',
+        'affinity_config',
         'monitor_local_loss',
         'monitor_device_local_loss',
         'profile',
@@ -515,7 +516,7 @@ class MFContextConfig(BaseArgsConfig):
             use_graceful_exit: bool = False,
             **kwargs,
     ):
-        super(MFContextConfig, self).__init__(
+        super().__init__(
             exclude_cann_cpu=exclude_cann_cpu,
             postprocess_use_numpy=postprocess_use_numpy,
             use_graceful_exit=use_graceful_exit,
@@ -721,7 +722,7 @@ class ParallelContextConfig(BaseArgsConfig):
             auto_pipeline: bool = False,
             **kwargs
     ):
-        super(ParallelContextConfig, self).__init__(
+        super().__init__(
             parallel_mode=parallel_mode,
             device_num=device_num,
             gradients_mean=gradients_mean,
@@ -747,7 +748,7 @@ class ParallelConfig(BaseArgsConfig):
 
     # pylint: disable=W0235
     def __init__(self, **kwargs):
-        super(ParallelConfig, self).__init__(**kwargs)
+        super().__init__(**kwargs)
 
 
 @dataclass
@@ -788,7 +789,7 @@ class ConfigArguments(BaseArgsConfig):
             save_checkpoint: Optional[Union[dict, BaseArgsConfig]] = None,
             cloud_config: Optional[Union[dict, BaseArgsConfig]] = None,
     ):
-        super(ConfigArguments, self).__init__(
+        super().__init__(
             output_dir=output_dir,
             profile=profile,
             auto_tune=auto_tune,
diff --git a/mindformers/core/context/build_context.py b/mindformers/core/context/build_context.py
index 1f2109a9989dbc88fd5d2fce8d9935b004b42fe6..89b0371f23bf21b0facca39eb2d4fba6225ea7b9 100644
--- a/mindformers/core/context/build_context.py
+++ b/mindformers/core/context/build_context.py
@@ -35,7 +35,8 @@ from mindformers.tools.register import MindFormerConfig
 from mindformers.tools.utils import (
     MODE,
     get_output_subpath,
-    check_in_dynamic_cluster
+    check_in_dynamic_cluster,
+    get_real_local_rank
 )
 from mindformers.utils import get_cann_workqueue_cores
 from mindformers.version_control import (
@@ -64,6 +65,7 @@ class Context:
         self.ms_ctx_opr = MSContextOperator(self.config)
         self.parallel_opr = ParallelOperator(self.config)
         if check_tft_valid() and ("ARF:1" in os.getenv("MS_ENABLE_TFT", "")):
+            # pylint: disable=C0415
             from mindspore.utils import _tft_handler
             _tft_handler.init(config=self.config)
         if self.config.use_parallel:
@@ -71,7 +73,26 @@ class Context:
                 self.parallel_opr.init_communication()
             )
         set_cpu_affinity(self.rank_id, self.device_num)
-        set_ms_affinity(self.config.get('context', {}).get('affinity_cpu_list', {}))
+
+        context_config = self.config.get('context', {})
+        has_affinity_cpu_list = 'affinity_cpu_list' in context_config
+        has_affinity_config = 'affinity_config' in context_config
+        if has_affinity_config:
+            if has_affinity_cpu_list:
+                logger.warning('affinity_cpu_list will be removed in the near future, '
+                               'affinity_config is taking effect.')
+            affinity_config = context_config.get('affinity_config', {})
+            set_ms_affinity(affinity_config, None)
+        elif has_affinity_cpu_list:
+            logger.warning('affinity_cpu_list will be removed in the near future, '
+                           'please use affinity_config instead.')
+            affinity_cpu_list = context_config.get('affinity_cpu_list', {})
+            if not isinstance(affinity_cpu_list, dict):
+                logger.warning(f'custom bind policy affinity_cpu_list must be dict, but got {affinity_cpu_list}.')
+            else:
+                set_ms_affinity(None, affinity_cpu_list)
+        else:
+            set_ms_affinity(None, None)
 
         self._initailed = True
 
@@ -220,7 +241,7 @@ class MFContextOperator(MFContextConfig):
         self.config = config
         supported_kwargs = self._handle_data()
         logger.debug('MFContextConfig load configs: %s', supported_kwargs)
-        super(MFContextOperator, self).__init__(**supported_kwargs)
+        super().__init__(**supported_kwargs)
         use_past = self.config.get_value('model.model_config.use_past', False)
 
         if not hasattr(self, 'train_precision_sync'):
@@ -275,10 +296,9 @@ class MFContextOperator(MFContextConfig):
                                f"LCCL_DETERMINISTIC: {os.getenv('LCCL_DETERMINISTIC')}")
                 return '', ''
             raise e
-        else:
-            if deterministic:
-                return 'off', '1'
-            return 'on', '0'
+        if deterministic:
+            return 'off', '1'
+        return 'on', '0'
 
     def _get_precision_env(self):
         """Set deterministic computing and get relative env variable."""
@@ -287,7 +307,7 @@ class MFContextOperator(MFContextConfig):
         run_mode = getattr(self, 'run_mode') if hasattr(self, 'run_mode') else None
         if run_mode in (
                 RunMode.TRAIN.value, RunMode.FINETUNE.value
-            ) and self.train_precision_sync is not None:
+        ) and self.train_precision_sync is not None:
             _, _ = self._call_ms_deterministic(self.train_precision_sync)
 
         if run_mode == RunMode.PREDICT.value and self.infer_precision_sync is not None:
@@ -357,12 +377,32 @@ class MFContextOperator(MFContextConfig):
         return getattr(self, attr_key, None)
 
 
-def set_ms_affinity(affinity_config):
-    """set mindspore cpu affinity"""
-    if not isinstance(affinity_config, dict):
-        logger.warning(f'custom bind policy affinity_cpu_list must be dict, but got {affinity_config}.')
-        return
-    ms.runtime.set_cpu_affinity(True, affinity_config)
+def set_ms_affinity(affinity_config, affinity_cpu_list):
+    """
+    Set mindspore cpu affinity.
+
+    A non-empty `affinity_config` always takes precedence over `affinity_cpu_list`.
+
+    Args:
+        affinity_config (dict): Policies keyed by `device_{id}`, where `id` is the value returned by
+            `get_real_local_rank()`. An entry may hold `affinity_cpu_list` and `module_to_cpu_dict`.
+        affinity_cpu_list (dict): Deprecated policy, forwarded unchanged when `affinity_config` is empty.
+    """
+    module_to_cpu_dict = None
+    if affinity_config:
+        if not isinstance(affinity_config, dict):
+            logger.warning(f'custom bind policy affinity_config must be dict, but got {affinity_config}.')
+            return
+        device_key = f'device_{get_real_local_rank()}'
+        # A missing or empty entry for this device passes None for both policies.
+        device_config = affinity_config.get(device_key) or {}
+        if not isinstance(device_config, dict):
+            logger.warning(f'custom bind policy {device_key} must be dict, but got {device_config}.')
+            return
+        affinity_cpu_list = device_config.get('affinity_cpu_list')
+        module_to_cpu_dict = device_config.get('module_to_cpu_dict')
+
+    ms.runtime.set_cpu_affinity(True, affinity_cpu_list, module_to_cpu_dict)
 
 
 def set_cpu_affinity(rank_id, rank_size):