diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 9c659a22df65cc4aa8b2843a950687e5e5fd82aa..a724c291060298fc7001adfc1037bdd419794c4c 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -797,6 +797,8 @@ class MonitorConst: ) DEEPSPEED_ZERO_OPT_FILTER = "DeepSpeedZeroOptimizer" RULE_NAME = ['AnomalyTurbulence', 'AnomalyNan'] + L2_HOOKS = ["linear_hook", "attention_hook"] + SA_ORDERS = ["s,b,h,d", "b,s,h,d"] SLICE_SIZE = 20480 # used for name diff --git a/debug/accuracy_tools/msprobe/core/monitor/utils.py b/debug/accuracy_tools/msprobe/core/monitor/utils.py index f19e14d89e6b7cb29d8bdc756a5d258081b106ab..2c7a20a040aa47762fec1b1ca37944fe64d1c793 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/utils.py +++ b/debug/accuracy_tools/msprobe/core/monitor/utils.py @@ -96,8 +96,33 @@ def validate_targets(targets): raise TypeError('key of targets should be module_name[str] in config.json') if not isinstance(field, dict): raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json') + - +def validate_l2_targets(targets): + if not isinstance(targets, dict): + raise TypeError('l2_targets in config.json should be a dict') + for hook_name, target_list in targets.items(): + if hook_name not in MonitorConst.L2_HOOKS: + raise TypeError(f'key of l2_targtes must be in {MonitorConst.L2_HOOKS}, got {hook_name}') + if not isinstance(target_list, list): + raise TypeError('values of l2_targets should be a list in config.json') + for item in target_list: + if not isinstance(item, str): + raise TypeError(f'item of "{hook_name}" in l2_targets should be module_name[str] in config.json') + + +def validate_recording_l2_features(recording_l2_features): + if not isinstance(recording_l2_features, bool): + raise TypeError("recording_l2_features should be a bool") + + +def validate_sa_order(sa_order): + if isinstance(sa_order, str): + sa_order = sa_order.replace(' ', '') + if sa_order not in MonitorConst.SA_ORDERS: + raise TypeError(f'sa_order must be in {MonitorConst.SA_ORDERS}, got {sa_order}') + + def validate_print_struct(print_struct): if not isinstance(print_struct, bool): raise TypeError("print_struct should be a bool") @@ -216,6 +241,15 @@ def validate_config(config): targets = config.get("targets", {}) validate_targets(targets) + l2_targets = config.get("l2_targets", {}) + validate_l2_targets(l2_targets) + + recording_l2_features = config.get("recording_l2_features", False) + validate_recording_l2_features(recording_l2_features) + + sa_order = config.get("sa_order", "s,b,h,d") + validate_sa_order(sa_order) + print_struct = config.get('print_struct', False) validate_print_struct(print_struct) diff --git a/debug/accuracy_tools/msprobe/docs/19.monitor.md b/debug/accuracy_tools/msprobe/docs/19.monitor.md index 3d4be725da17a333a8bd19a3bc8069bc291ba247..69b9e5f23b23b3a58c559e2f6008167a7cde98b4 100644 --- a/debug/accuracy_tools/msprobe/docs/19.monitor.md +++ b/debug/accuracy_tools/msprobe/docs/19.monitor.md @@ -24,6 +24,7 @@ | [采集module堆栈信息](#采集module堆栈信息) | 采集监控的第一个 step 的 module 对应的堆栈信息辅助问题定位 | PyTorch、MindSpore | | [指定监控对象](#指定监控对象) | 指定监控的nn.Module(nn.Cell)及对应的输入输出 | PyTorch、MindSpore | | [打印模型结构](#打印模型结构) | 打印模型结构 | PyTorch | +| [l2可解释特征监控](#l2可解释特征监控) | 开启模型状态的高阶监控 | PyTorch | | [输出格式和统计量](#输出格式和统计量) | format PyTorch支持`csv`、`tensorboard`和`api`,MindSpore仅支持`csv`,`ops`、`ndigits`均支持 | PyTorch、MindSpore | | [mbs粒度梯度监控](#mbs粒度梯度监控) | 
开启梯度监控时,采集聚合前梯度时支持`micro_batch_size`粒度 | PyTorch、MindSpore | | [异常告警](#异常告警) | 监控对象指标异常时自动告警,支持异常数据落盘 | PyTorch、MindSpore | @@ -302,6 +303,34 @@ param_name可以通过nn.Module的接口`named_parameters()`获取。 } ``` +### l2可解释特征监控 +- 工具配置示例 +```json +{ + "l2_targets": { + "attention_hook": ["0:0.self_attention.core_attention.flash_attention"], + "linear_hook": ["0:0.self_attention.linear_qkv", "0:1.self_attention.linear_qkv"] + }, + "recording_l2_features": true, + "sa_order": "b,s,h,d" +} +``` +| 配置项 | 类型 | 说明 | 是否必选 | +|--------|------|------|--------| +| **l2_targets** | Dict[str, List[str]] | 指定需要监控的模型层配置
**支持的hook类型**:
• `attention_hook`:监控注意力层
  ▪️ 采集指标:`entropy`, `softmax_max`
  ▪️ 必须通过[打印模型结构](#打印模型结构)获取准确层名
  ▪️ 不配置或配置空列表均表示不采集
• `linear_hook`:监控线性层
  ▪️ 采集指标:`sr`, `kernel_norm`
  ▪️ 必须通过[打印模型结构](#打印模型结构)获取准确层名, 不配置表示不采集
  ▪️ 配置空列表会自动识别符合条件的层(包含 2D 的`weight`或`wg`参数属性的层) | 是 |
+| **recording_l2_features** | bool | 是否开启L2层特征数据采集,默认为false表示不采集 | 否 |
+| **sa_order** | str | 计算`attention_hook`内指标时,指定Attention输入(Q,K)的张量维度排列顺序,支持"s,b,h,d"和"b,s,h,d",默认为"s,b,h,d",表示输入维度顺序为**s**equence_len->**b**atch_size->num_**h**eads->head_**d**im | 否 |
+
+
+#### L2可解释特征监控指标说明
+
+| **指标名称** | **适用Hook类型** | **数学定义/计算方式** | **监控意义** |
+|--------------------|-------------------|-------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------|
+| **entropy** | attention_hook | $H(p)=-\sum p_i \log p_i$,其中$p_i$为注意力权重 | 衡量注意力分布的不确定性,**低熵值**表示注意力集中 |
+| **softmax_max** | attention_hook | $\max(\text{softmax}(QK^T/\sqrt{d}))$ | 反映注意力机制的聚焦程度,**高值**表示存在显著主导的注意力token |
+| **sr(stable_rank)** | linear_hook | $\frac{\|W\|_F}{\|W\|_2}$(稳定秩,Frobenius范数除以谱范数) | 评估权重矩阵的有效秩,**低值**表示矩阵接近低秩不稳定状态 |
+| **kernel_norm** | linear_hook | $\|W\|_2$(谱范数,即最大奇异值) | 权重矩阵的谱范数,反映输入在矩阵最大奇异向量张成空间的放大系数 |
+
 ### 输出格式和统计量
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/features.py b/debug/accuracy_tools/msprobe/pytorch/monitor/features.py
index cfd0c1615d2072147d12e7e279527bef86c9d497..960f3dabe32acb8c8ed3018f7421747131e02984 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/features.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/features.py
@@ -111,3 +111,97 @@ def cal_histc(tensor_cal, bins_total, min_val, max_val):
 @torch.no_grad()
 def get_nans(t):
     return torch.isnan(t).sum()
+
+
+def check_tensor_dim(tensor, n):
+    """检查张量维度是否至少为n
+    """
+    if not isinstance(tensor, torch.Tensor):
+        raise TypeError(
+            f"Input must be a PyTorch tensor. Got {type(tensor)} instead. "
+            f"Consider using torch.tensor() for conversion."
+        )
+
+    if tensor.dim() < n:
+        raise ValueError(
+            f"Tensor must have at least {n} dimensions. "
+            f"Got shape: {tuple(tensor.shape)} with {tensor.dim()} dims."
+        )
+
+
+@torch.no_grad()
+def max_eigenvalue(input_tensor: torch.Tensor, num_iterations=3):
+    input_tensor = input_tensor.float()
+    try:
+        check_tensor_dim(input_tensor, 2)
+    except (TypeError, ValueError) as e:
+        logger.warning(f"Calculate max eigenvalue failed: {e}")
+        return torch.tensor(0)
+    in_features = input_tensor.shape[1]
+    u_tensor = torch.randn(in_features).to(input_tensor.device)
+    u_norm = u_tensor.norm()
+    if u_norm.item() == 0:
+        return torch.tensor(0)
+    u_tensor = u_tensor / u_norm
+    input_seq = torch.matmul(input_tensor.T, input_tensor)
+    for _ in range(num_iterations):
+        v_tensor = torch.matmul(input_seq, u_tensor)
+        spectral_norm = torch.matmul(v_tensor.T, u_tensor)
+        v_norm = v_tensor.norm()
+        if v_norm > 0:
+            u_tensor = v_tensor / v_norm
+        else:
+            spectral_norm = torch.tensor(0)
+            break
+    return spectral_norm.sqrt()
+
+
+@torch.no_grad()
+def cal_entropy(qk_tensor, mask=None):
+    try:
+        check_tensor_dim(qk_tensor, 2)
+    except (TypeError, ValueError) as e:
+        logger.warning(f"Calculate entropy failed: {e}")
+        return torch.tensor(0), torch.tensor(0)
+    if mask is None:
+        mask = torch.tril(torch.ones(qk_tensor.shape[1], qk_tensor.shape[1])).to(
+            qk_tensor.device)
+    qk_tensor = qk_tensor - torch.amax(qk_tensor, dim=1, keepdim=True)
+    qk_tensor = qk_tensor.masked_fill(mask == 0, float('-inf'))
+    softmax_qkt = torch.nn.functional.softmax(qk_tensor.float(), dim=1)
+    # 取softmax后每行的最大注意力权重
+    softmax_max = torch.mean(torch.amax(softmax_qkt, dim=1))
+    entropy = torch.mean(-torch.nansum(softmax_qkt *
+                                       torch.log(softmax_qkt), dim=1))
+    return entropy, softmax_max
+
+
+@torch.no_grad()
+def cal_qkt(q_h, k_h, order="s,b,h,d"):
+    # q_h shape is [s, b, h, d]
+    try:
+        check_tensor_dim(q_h, 4)
+        check_tensor_dim(k_h, 4)
+    except (TypeError, ValueError) as e:
+        logger.warning(f"Calculate qk tensor failed: {e}")
+        return torch.tensor(0)
+
+    if order == "s,b,h,d":
+        qkt = torch.matmul(
+            q_h[:, 0, 0, :], k_h[:, 0, 0, :].t()) / q_h.shape[-1] ** 0.5
+    elif order == "b,s,h,d":
+        qkt = torch.matmul(
+            q_h[0, :, 0, :], k_h[0, :, 0, :].t()) / q_h.shape[-1] ** 0.5
+    else:
+        logger.warning("Calculate qk tensor failed: Order unsupported.")
+        qkt = torch.tensor(0)
+    return qkt
+
+
+@torch.no_grad()
+def cal_stable_rank(weight: torch.Tensor):
+    eig = max_eigenvalue(weight)
+    if eig == torch.tensor(0):
+        return torch.tensor(0), torch.tensor(0)
+    f_norm = torch.norm(weight, p="fro")
+    return f_norm / eig, eig
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
index f39cd7a83384d4f9b3342d77a0f6183b051c7089..0a2cd447c9fc32fb8e6ebcb2aa1baab05577c637 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
@@ -40,9 +40,9 @@ from msprobe.pytorch.monitor.utils import get_param_struct
 from msprobe.pytorch.monitor.data_writers import SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD, WriterInput
 from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
     get_process_group
-from msprobe.pytorch.monitor.features import get_sign_matches
+from msprobe.pytorch.monitor.features import get_sign_matches, cal_qkt
 from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
-    TensorMetrics, squash_param_name
+    TensorMetrics, squash_param_name, get_entropy_metric, get_sr_metric
 from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
 from 
msprobe.pytorch.monitor.visualizer import HeatmapVisualizer @@ -57,6 +57,7 @@ FORMAT_MAPPING = { MonitorConst.CSV: CSVWriterWithAD, MonitorConst.API: BaseWriterWithAD } +start_step = 0 def param_is_not_tensor_parallel_duplicate(param, tp_group): @@ -83,7 +84,17 @@ class ModuleHookContext: self.actvgrad.clear() -start_step = 0 +class FeatureHookContext: + def __init__(self, module_name): + self.step = 0 + self.micro_step = 0 + self.attention_feature = {} + self.linear_feature = {} + self.module_name = module_name + + def reset(self): + self.attention_feature.clear() + self.linear_feature.clear() class OptimizerContext: @@ -206,6 +217,7 @@ class TrainerMon: # TYPE3: 会随着训练中途config配置更新或监控状态改变而重置的变量 self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext) self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext) + self.feature_hook_context_by_module = defaultdict(FeatureHookContext) self.optimizer_context = defaultdict(OptimizerContext) self.cc_context = defaultdict(CommunicationContext) self.grad_context = GradContext() @@ -274,6 +286,18 @@ class TrainerMon: cc_tensor.reset() return metrics + @staticmethod + def get_linear_hook_target(module): + if isinstance(module, torch.nn.Embedding): + return '' + if hasattr(module, "num_embeddings") or hasattr(module, "vocab_start_index"): + return '' + for weight_name in ["weight", "wg"]: + if hasattr(module, weight_name) and isinstance(getattr(module, weight_name), torch.Tensor): + if getattr(module, weight_name).dim() == 2: + return weight_name + return '' + def set_config(self): logger.info(f"current config: {self.config}") self.start_step = self.config.get("start_step", 0) @@ -298,6 +322,8 @@ class TrainerMon: self.cc_distribution = self.config.get("cc_distribution", {}) self.stack_info = self.config.get('stack_info', False) self.monitor_mbs_grad = self.config.get('monitor_mbs_grad', False) + self.recording_l2_features = self.config.get("recording_l2_features", False) + self.sa_order = self.config.get("sa_order", "s,b,h,d") if not self.cc_distribution.get('enable', False): self.cc_log_only = False @@ -356,6 +382,8 @@ class TrainerMon: logger.info_on_rank_0("> momentum and variance of adam is not monitored. ") if not self.wg_distribution: logger.info_on_rank_0("> weight grad of specified module is not monitored. ") + if not self.recording_l2_features: + logger.info_on_rank_0("> l2 features of specified module is not monitored. 
") if not self.mg_direction: logger.info_on_rank_0('> grad and momentum direction will not be compared.') if not self.cc_distribution.get('enable', False): @@ -537,6 +565,24 @@ class TrainerMon: if self.grad_context.actv: self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD) + def write_metrics_if_not_empty(self, features, metrics, step, hook_name): + if not features or len(features) == 0: + return + use_micro_step = hook_name not in ["linear_hook"] + self.summary_writer.write_metrics(metrics, features, step, hook_name, use_micro_step=use_micro_step) + features.clear() + + def write_features_tb(self, step): + if not self.recording_l2_features: + return + for context in self.feature_hook_context_by_module.values(): + num_features = len(context.attention_feature) + len(context.linear_feature) + if num_features == 0: + continue + self.write_metrics_if_not_empty(context.attention_feature, ["entropy", "softmax_max"], + step, "attention_hook") + self.write_metrics_if_not_empty(context.linear_feature, ["sr", "kernel_norm"], step, "linear_hook") + def write_param_tb(self, opt_context): if not self.param_distribution: return @@ -691,6 +737,7 @@ class TrainerMon: if self.anomaly_data_factory: self.anomaly_data_factory.set_call_id(self.param_name_call_id) self.write_xy_tb(context.step) + self.write_features_tb(context.step) self.write_grad_tb(context.step) self.write_mv_tb(context) self.write_param_tb(context) @@ -760,7 +807,8 @@ class TrainerMon: vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}' targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[ 'targets'].keys() - hooked_count += self._hook_module(targets, model_chunk, vpp_stage) + l2_target_names = self.config.get('l2_targets', '') + hooked_count += self._hook_module(targets, l2_target_names, model_chunk, vpp_stage) logger.info_on_rank_0(f"> {hooked_count} modules are monitored.") @@ -801,6 +849,9 @@ class TrainerMon: for handle in self.handles['xy']: handle.remove() self.handles['xy'].clear() + for handle in self.handles['L2_features']: + handle.remove() + self.handles['L2_features'].clear() # 清空对应context缓存 for _, fwd_context in self.module_fwd_hook_context_by_module.items(): fwd_context.reset() @@ -941,7 +992,20 @@ class TrainerMon: return pattern return "" - def _hook_module(self, target_names, module: torch.nn.Module, vpp_stage=''): + def _is_recording_module(self, module_name, l2_targets, vpp_stage, hook_name): + + if len(l2_targets) > 0: + for pattern in [ + vpp_stage + squash_param_name(module_name, self.squash_name), + vpp_stage + module_name, + ]: + if pattern in l2_targets: + return pattern + elif hook_name in ["linear_hook"]: + return vpp_stage + squash_param_name(module_name, self.squash_name) + return "" + + def _hook_module(self, target_names, l2_target_names, module: torch.nn.Module, vpp_stage=''): if '_modules' not in module.__dict__: # nothing to hook return 0 @@ -1020,6 +1084,61 @@ class TrainerMon: context.micro_step = 0 return + def extract_attention_feature_hook(module, module_input, module_output, name): + if is_recomputation() or not module.training: + return + + if module not in self.feature_hook_context_by_module: + self.feature_hook_context_by_module[module] = FeatureHookContext(name) + context: FeatureHookContext = self.feature_hook_context_by_module[module] + tbtag_tensor_map = {} + if len(module_input) < 2: + logger.warning( + f"Length of module_input in attention hook ({name}) is {len(module_input)}, " + "expected >= 2. 
Skipping feature extraction for this module." + ) + return + q_h = module_input[0] + k_h = module_input[1] + qkt = cal_qkt(q_h, k_h, order=self.sa_order) + tbtag_tensor_map.update( + self.build_tbtag_tensor_map(f'{context.module_name}.attention', + f'{MonitorConst.NAME_SEP}{context.micro_step}', 'qkt', qkt) + ) + get_entropy_metric(tbtag_tensor_map, context.attention_feature) + + context.micro_step += 1 + if context.micro_step == self.micro_batch_number: + context.micro_step = 0 + context.step += 1 + return + + def extract_linear_sr_hook(module, module_input, module_output, name): + if is_recomputation() or not module.training: + return + weight_name = self.get_linear_hook_target(module) + if weight_name == '': + return + + if module not in self.feature_hook_context_by_module: + self.feature_hook_context_by_module[module] = FeatureHookContext(name) + context: FeatureHookContext = self.feature_hook_context_by_module[module] + + if context.micro_step == (self.micro_batch_number - 1): + tbtag_tensor_map = {} + value = getattr(module, weight_name).data + tbtag_tensor_map.update( + self.build_tbtag_tensor_map(f'{context.module_name}.linear', + '', 'sr', value) + ) + get_sr_metric(tbtag_tensor_map, context.linear_feature) + + context.micro_step += 1 + if context.micro_step == self.micro_batch_number: + context.micro_step = 0 + context.step += 1 + return + def stack_hook(module, args, kwargs, module_output, name): if module not in self.module_fwd_hook_context_by_module: self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name) @@ -1051,6 +1170,26 @@ class TrainerMon: self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name) logger.info_on_rank_0(f"> {name} is monitored successfully") hooked_count += 1 + if not self.print_struct and self.recording_l2_features: + for module_name, submodule in module.named_modules(): + func_map = { + "attention_hook": extract_attention_feature_hook, + "linear_hook": extract_linear_sr_hook, + } + for hook_name in func_map.keys(): + if hook_name not in l2_target_names: + continue + temp_names = l2_target_names[hook_name] + name = self._is_recording_module(module_name, temp_names, vpp_stage, hook_name) + if name: + handle = submodule.register_forward_hook(partial(func_map[hook_name], name=name)) + print_feature_name = hook_name.split('_')[0] + logger.info_on_rank_0( + f'> {print_feature_name} features of {name} is monitored successfully') + self.handles["L2_features"].append(handle) + hooked_count += 1 + continue + return hooked_count def _patch_grad_sync(self): diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py index c5730d7846ca915d524eb1c3ce76e805f321a652..24e9b43d6378e37f09d3c654511c840e9168f6f6 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py @@ -17,6 +17,7 @@ import re import torch from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean +from msprobe.pytorch.monitor.features import cal_entropy, cal_stable_rank from msprobe.pytorch.monitor.utils import get_nan_tensor @@ -185,3 +186,27 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None): fun_metric = config_metric_registry.get(metric_name) out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps) return out_dict + + +def get_sr_metric(tag2tensor, out_dict=None): + if out_dict is None: + out_dict = {} + for tag, tensor in tag2tensor.items(): + if 
"sr" not in tag: + continue + if tag not in out_dict: + out_dict[tag] = {} + sr, eig = cal_stable_rank(tensor) + out_dict[tag]['sr'] = sr + out_dict[tag]['kernel_norm'] = eig + + +def get_entropy_metric(tag2tensor, out_dict=None): + if out_dict is None: + out_dict = {} + for tag, tensor in tag2tensor.items(): + if tag not in out_dict: + out_dict[tag] = {} + entropy, softmax_max = cal_entropy(tensor) + out_dict[tag]['entropy'] = entropy + out_dict[tag]['softmax_max'] = softmax_max diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py index ff00cf7490d8110f2198df57ee5d91b6b75f5092..a2e7ecb512d195305e48badd03c7e1b4ae30b237 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py @@ -1,8 +1,10 @@ import unittest +from unittest.mock import patch + import torch from msprobe.pytorch.monitor.features import square_sum, get_min, get_mean, get_norm, get_max, get_zeros, \ get_sign_matches, eff_rank, mNTK, lambda_max_subsample, cal_histc, get_nans - +from msprobe.pytorch.monitor.features import max_eigenvalue, cal_entropy, cal_qkt, cal_stable_rank class TestMathFunctions(unittest.TestCase): def test_square_sum(self): @@ -87,6 +89,74 @@ class TestMathFunctions(unittest.TestCase): result = get_nans(tensor) self.assertEqual(result, 1) + def test_max_eigenvalue(self): + """测试最大特征值计算""" + # 创建已知特征值的矩阵 + A = torch.diag(torch.tensor([3.0, 2.0, 1.0])) + + # 测试不同迭代次数 + eigval = max_eigenvalue(A, num_iterations=5) + self.assertAlmostEqual(eigval.item(), 3.0, delta=0.1) + + # 测试全零矩阵 + zero_matrix = torch.zeros(3, 3) + eigval = max_eigenvalue(zero_matrix) + self.assertAlmostEqual(eigval.item(), 0.0) + + def test_cal_entropy(self): + """测试注意力熵计算""" + # 创建简单的注意力分数 + qk = torch.tensor([[1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0]]) + + # 无mask + entropy, softmax_max = cal_entropy(qk) + self.assertAlmostEqual(entropy, 0.4715, delta=0.1) + self.assertAlmostEqual(softmax_max, 0.7988, delta=0.1) + + # 带mask 和默认生成相同 + mask = torch.tensor([[1, 0, 0], + [1, 1, 0], + [1, 1, 1]], dtype=torch.float) + entropy, softmax_max = cal_entropy(qk, mask) + self.assertAlmostEqual(entropy, 0.4715, delta=0.1) + self.assertAlmostEqual(softmax_max, 0.7988, delta=0.1) + + @patch("msprobe.pytorch.monitor.features.logger") + def test_cal_qkt(self, mock_logger): + """测试QK^T计算""" + # 测试s,b,h,d顺序 + q = torch.randn(10, 2, 4, 8) # [s, b, h, d] + k = torch.randn(10, 2, 4, 8) # [s, b, h, d] + q_batch = torch.randn(2, 10, 4, 8) # [b, s, h, d] + qkt = cal_qkt(q, k, order="s,b,h,d") + self.assertEqual(qkt.shape, (10, 10)) # [s, s] + + # 测试b,s,h,d顺序 + qkt = cal_qkt(q_batch, q_batch, order="b,s,h,d") + self.assertEqual(qkt.shape, (10, 10)) # [s, s] + + # 测试无效顺序 + cal_qkt(q, k, order="invalid_order") + mock_logger.warning.assert_called_with( + "Calculate qk tensor failed: Order unsupported.") + + def test_cal_stable_rank(self): + """测试谱半径计算""" + # 创建已知谱半径的矩阵 + A = torch.diag(torch.tensor([3.0, 2.0, 1.0])) + sr, eig = cal_stable_rank(A) + + # 验证Frobenius范数 + fro_norm = torch.norm(A, p='fro') + self.assertAlmostEqual(sr, fro_norm / 3.0, delta=.5) # 最大特征值为3 + + # 测试正交矩阵 + ortho = torch.eye(5) + sr, eig = cal_stable_rank(ortho) + self.assertAlmostEqual(sr, torch.tensor(2.23/1), delta=.5) # F范数应为2.23 + self.assertAlmostEqual(eig, torch.tensor(1.0), delta=.1) # 特征值应为1 if __name__ == '__main__': unittest.main() diff --git 
a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py
index 83e8217c894d38b1d8506cb0e1cd241ffcbcb759..2f10d4f12906bfd91cfd304d157f5946dd68524b 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py
@@ -8,7 +8,7 @@ from msprobe.core.common.const import MonitorConst
 from msprobe.core.monitor.utils import filter_special_chars, MsgConst, validate_ops, validate_ranks, \
     validate_targets, validate_print_struct, validate_ur_distribution, validate_xy_distribution, \
     validate_mg_distribution, validate_wg_distribution, validate_cc_distribution, validate_alert, validate_config, \
-    get_output_base_dir
+    get_output_base_dir, validate_l2_targets, validate_recording_l2_features, validate_sa_order
 from msprobe.pytorch.monitor.utils import get_param_struct
 from msprobe.pytorch.common.utils import is_recomputation
@@ -112,6 +112,65 @@ class TestValidationFunctions(unittest.TestCase):
         self.assertEqual(config["targets"], {"": {}})
         self.assertEqual(config["all_xy"], True)
 
+    # ===== validate_l2_targets 测试 =====
+    def test_validate_l2_targets_valid_input(self):
+        """测试合法输入"""
+        valid_targets = {
+            "attention_hook": ["0:0.self_attention.core_attention.flash_attention"],
+            "linear_hook": []
+        }
+        validate_l2_targets(valid_targets)
+
+    def test_validate_l2_targets_invalid_root_type(self):
+        """测试非 dict 输入"""
+        with self.assertRaises(TypeError) as cm:
+            validate_l2_targets("not_a_dict")
+        self.assertEqual(str(cm.exception),
+                         'l2_targets in config.json should be a dict')
+
+    def test_validate_l2_targets_invalid_hook_name(self):
+        """测试非法 hook_name"""
+        with self.assertRaises(TypeError) as cm:
+            validate_l2_targets({"invalid_hook": ["module1"]})
+        self.assertIn(f'key of l2_targtes must be in {MonitorConst.L2_HOOKS}',
+                      str(cm.exception))
+
+    def test_validate_l2_targets_invalid_value_type(self):
+        """测试非法 value 类型"""
+        with self.assertRaises(TypeError) as cm:
+            validate_l2_targets({"linear_hook": "not_a_list"})
+        self.assertEqual(str(cm.exception),
+                         'values of l2_targets should be a list in config.json')
+
+    def test_validate_l2_targets_invalid_item_type(self):
+        """测试非法 list item 类型"""
+        with self.assertRaises(TypeError) as cm:
+            validate_l2_targets({"linear_hook": [123]})
+        self.assertEqual(str(cm.exception),
+                         'item of "linear_hook" in l2_targets should be module_name[str] in config.json')
+
+    # ===== validate_recording_l2_features 测试 =====
+    def test_validate_recording_l2_features_valid(self):
+        """测试合法布尔值输入"""
+        validate_recording_l2_features(True)
+        validate_recording_l2_features(False)
+
+    def test_validate_recording_l2_features_invalid_type(self):
+        """测试非法类型输入"""
+        with self.assertRaises(TypeError) as cm:
+            validate_recording_l2_features("xx")
+        self.assertEqual(str(cm.exception),
+                         "recording_l2_features should be a bool")
+
+    def test_valid_orders(self):
+        validate_sa_order("b,s,h,d")
+        validate_sa_order("s, b,h, d")
+
+    def test_invalid_orders(self):
+        with self.assertRaises(TypeError) as cm:
+            validate_sa_order("xx")
+        self.assertEqual(str(cm.exception),
+                         f'sa_order must be in {MonitorConst.SA_ORDERS}, got xx')
 
 class TestIsRecomputation(unittest.TestCase):
     @patch('inspect.stack')
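
补充示例(示意性质):下面演示 config.json 中新增字段与三个新增校验函数的配合使用,假设运行环境已包含上述改动,字段取值仅作演示。

```python
# 示意用法:校验 l2 可解释特征监控相关的新增配置项(假设环境中已应用上述改动)
from msprobe.core.monitor.utils import (
    validate_l2_targets,
    validate_recording_l2_features,
    validate_sa_order,
)

config = {
    "l2_targets": {
        # attention_hook 需通过“打印模型结构”获取准确层名
        "attention_hook": ["0:0.self_attention.core_attention.flash_attention"],
        # linear_hook 配置空列表时,自动识别包含 2D weight/wg 参数的层
        "linear_hook": [],
    },
    "recording_l2_features": True,
    "sa_order": "b,s,h,d",
}

# 各校验函数在取值非法时抛出 TypeError,合法时静默通过
validate_l2_targets(config["l2_targets"])
validate_recording_l2_features(config["recording_l2_features"])
validate_sa_order(config["sa_order"])
```

其中 `validate_sa_order` 会先去除空格再比较,因此 `"s, b,h, d"` 这类写法同样合法。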
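另外,为帮助理解 `linear_hook` 与 `attention_hook` 所采集指标的含义,下面给出一个不依赖 msprobe 的最小示例(仅使用 PyTorch;`approx_spectral_norm`、`stable_rank`、`attention_entropy` 为示意命名,计算方式与上文 `cal_stable_rank`、`cal_entropy` 的定义一致)。

```python
import torch


@torch.no_grad()
def approx_spectral_norm(weight: torch.Tensor, num_iterations: int = 3) -> torch.Tensor:
    """幂迭代近似谱范数(最大奇异值),对应指标表中的 kernel_norm。"""
    weight = weight.float()
    u = torch.randn(weight.shape[1], device=weight.device)
    if u.norm() == 0:
        return torch.tensor(0.0)
    u = u / u.norm()
    gram = weight.t() @ weight          # W^T W 的最大特征值是谱范数的平方
    sigma_sq = torch.tensor(0.0)
    for _ in range(num_iterations):
        v = gram @ u
        sigma_sq = v @ u                # Rayleigh 商,近似最大特征值
        if v.norm() == 0:
            return torch.tensor(0.0)
        u = v / v.norm()
    return sigma_sq.sqrt()


@torch.no_grad()
def stable_rank(weight: torch.Tensor):
    """sr = ||W||_F / ||W||_2,与 cal_stable_rank 的返回值对应。"""
    sigma = approx_spectral_norm(weight)
    if sigma == 0:
        return torch.tensor(0.0), torch.tensor(0.0)
    return torch.norm(weight, p="fro") / sigma, sigma


@torch.no_grad()
def attention_entropy(qkt: torch.Tensor):
    """对下三角掩码后的 QK^T 做 softmax,返回 (平均熵, 平均最大注意力权重)。"""
    seq_len = qkt.shape[-1]
    mask = torch.tril(torch.ones(seq_len, seq_len, device=qkt.device))
    scores = qkt.float().masked_fill(mask == 0, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    entropy = torch.mean(-torch.nansum(probs * torch.log(probs), dim=-1))
    softmax_max = torch.mean(torch.amax(probs, dim=-1))
    return entropy, softmax_max


if __name__ == "__main__":
    w = torch.diag(torch.tensor([3.0, 2.0, 1.0]))
    sr, sigma = stable_rank(w)                     # sr ≈ sqrt(14)/3 ≈ 1.25,sigma ≈ 3.0
    q = torch.randn(8, 16)                         # 单个 head 的 [s, d] 切片
    ent, smax = attention_entropy(q @ q.t() / 16 ** 0.5)
    print(sr.item(), sigma.item(), ent.item(), smax.item())
```

从 `stable_rank` 的输出也可以看出,指标表中 `kernel_norm` 对应的是谱范数(最大奇异值),而非 Frobenius 范数。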