diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py index c6d611d5c37fb279e951868899d54c16ada24f12..2bca42656017acfa982544100c5cd59b730e9640 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py @@ -13,11 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools +import importlib import types import torch +from msprobe.core.common.log import logger from msprobe.pytorch.hook_module.api_register import get_api_register from msprobe.pytorch.common.utils import torch_version_above_or_equal_2 + if torch_version_above_or_equal_2: from torch._dynamo.convert_frame import convert_frame as _orig_convert_frame, Hooks @@ -70,7 +74,56 @@ def wrap_compile_script_func(): _cf_mod.convert_frame = _patched_convert_frame +def patch_dynamo_compile(): + cf = importlib.import_module("torch._dynamo.convert_frame") + if not hasattr(cf, "_compile"): + logger.warning("No found torch._dynamo.convert_frame._compile") + + original = cf._compile + if getattr(original, "__msprobe_patched__", False): + return + + @functools.wraps(original) + def wrapped(*args, **kwargs): + result = None + try: + reg = get_api_register() + reg.restore_all_api() + except Exception as e: + logger.warning(f"[msprobe] Pre restore_all_api failed: {e}") + return result + + try: + result = original(*args, **kwargs) + except Exception: + logger.warning("[msprobe] _compile execution failed (returning None)") + result = None + finally: + try: + reg = get_api_register() + reg.register_all_api() # 改成注册hook + except Exception as e: + logger.warning(f"[msprobe] Post register_all_api failed: {e}") + return result + wrapped.__msprobe_patched__ = True + wrapped.__msprobe_original__ = original + cf._compile = wrapped + + +def unpatch_dynamo_compile() -> bool: + # 预留取消patch接口 + cf = importlib.import_module("torch._dynamo.convert_frame") + current = getattr(cf, "_compile", None) + if current is None: + return False + original = getattr(current, "__msprobe_original__", None) + if original is None: + return False + cf._compile = original + return True + + def wrap_script_func(): wrap_jit_script_func() if torch_version_above_or_equal_2: - wrap_compile_script_func() \ No newline at end of file + patch_dynamo_compile()