diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
index 9c659a22df65cc4aa8b2843a950687e5e5fd82aa..a724c291060298fc7001adfc1037bdd419794c4c 100644
--- a/debug/accuracy_tools/msprobe/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -797,6 +797,8 @@ class MonitorConst:
)
DEEPSPEED_ZERO_OPT_FILTER = "DeepSpeedZeroOptimizer"
RULE_NAME = ['AnomalyTurbulence', 'AnomalyNan']
+ L2_HOOKS = ["linear_hook", "attention_hook"]
+ SA_ORDERS = ["s,b,h,d", "b,s,h,d"]
SLICE_SIZE = 20480
# used for name
diff --git a/debug/accuracy_tools/msprobe/core/monitor/utils.py b/debug/accuracy_tools/msprobe/core/monitor/utils.py
index f19e14d89e6b7cb29d8bdc756a5d258081b106ab..2c7a20a040aa47762fec1b1ca37944fe64d1c793 100644
--- a/debug/accuracy_tools/msprobe/core/monitor/utils.py
+++ b/debug/accuracy_tools/msprobe/core/monitor/utils.py
@@ -96,8 +96,33 @@ def validate_targets(targets):
raise TypeError('key of targets should be module_name[str] in config.json')
if not isinstance(field, dict):
raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json')
+
-
+def validate_l2_targets(targets):
+ if not isinstance(targets, dict):
+ raise TypeError('l2_targets in config.json should be a dict')
+ for hook_name, target_list in targets.items():
+ if hook_name not in MonitorConst.L2_HOOKS:
+ raise TypeError(f'key of l2_targets must be in {MonitorConst.L2_HOOKS}, got {hook_name}')
+ if not isinstance(target_list, list):
+ raise TypeError('values of l2_targets should be a list in config.json')
+ for item in target_list:
+ if not isinstance(item, str):
+ raise TypeError(f'item of "{hook_name}" in l2_targets should be module_name[str] in config.json')
+
+
+def validate_recording_l2_features(recording_l2_features):
+ if not isinstance(recording_l2_features, bool):
+ raise TypeError("recording_l2_features should be a bool")
+
+
+def validate_sa_order(sa_order):
+ if isinstance(sa_order, str):
+ sa_order = sa_order.replace(' ', '')
+ if sa_order not in MonitorConst.SA_ORDERS:
+ raise TypeError(f'sa_order must be in {MonitorConst.SA_ORDERS}, got {sa_order}')
+
+
def validate_print_struct(print_struct):
if not isinstance(print_struct, bool):
raise TypeError("print_struct should be a bool")
@@ -216,6 +241,15 @@ def validate_config(config):
targets = config.get("targets", {})
validate_targets(targets)
+ l2_targets = config.get("l2_targets", {})
+ validate_l2_targets(l2_targets)
+
+ recording_l2_features = config.get("recording_l2_features", False)
+ validate_recording_l2_features(recording_l2_features)
+
+ sa_order = config.get("sa_order", "s,b,h,d")
+ validate_sa_order(sa_order)
+
print_struct = config.get('print_struct', False)
validate_print_struct(print_struct)
diff --git a/debug/accuracy_tools/msprobe/docs/19.monitor.md b/debug/accuracy_tools/msprobe/docs/19.monitor.md
index 3d4be725da17a333a8bd19a3bc8069bc291ba247..69b9e5f23b23b3a58c559e2f6008167a7cde98b4 100644
--- a/debug/accuracy_tools/msprobe/docs/19.monitor.md
+++ b/debug/accuracy_tools/msprobe/docs/19.monitor.md
@@ -24,6 +24,7 @@
| [采集module堆栈信息](#采集module堆栈信息) | 采集监控的第一个 step 的 module 对应的堆栈信息辅助问题定位 | PyTorch、MindSpore |
| [指定监控对象](#指定监控对象) | 指定监控的nn.Module(nn.Cell)及对应的输入输出 | PyTorch、MindSpore |
| [打印模型结构](#打印模型结构) | 打印模型结构 | PyTorch |
+| [L2 interpretable feature monitoring](#l2-interpretable-feature-monitoring) | Enables higher-level monitoring of model state | PyTorch |
| [输出格式和统计量](#输出格式和统计量) | format PyTorch支持`csv`、`tensorboard`和`api`,MindSpore仅支持`csv`,`ops`、`ndigits`均支持 | PyTorch、MindSpore |
| [mbs粒度梯度监控](#mbs粒度梯度监控) | 开启梯度监控时,采集聚合前梯度时支持`micro_batch_size`粒度 | PyTorch、MindSpore |
| [异常告警](#异常告警) | 监控对象指标异常时自动告警,支持异常数据落盘 | PyTorch、MindSpore |
@@ -302,6 +303,34 @@ param_name可以通过nn.Module的接口`named_parameters()`获取。
}
```
+### L2 interpretable feature monitoring
+- Example tool configuration
+```json
+{
+ "l2_targets": {
+ "attention_hook": ["0:0.self_attention.core_attention.flash_attention"],
+ "linear_hook": ["0:0.self_attention.linear_qkv", "0:1.self_attention.linear_qkv"]
+ },
+ "recording_l2_features": true,
+ "sa_order": "b,s,h,d"
+}
+```
+| Option | Type | Description | Required |
+|--------|------|------|--------|
+| **l2_targets** | Dict[str, List[str]] | Model layers to monitor.<br>**Supported hook types**:<br>• `attention_hook`: monitors attention layers<br>&emsp;▪️ collected metrics: `entropy`, `softmax_max`<br>&emsp;▪️ exact layer names must be obtained via [printing the model structure](#打印模型结构)<br>&emsp;▪️ omitting the key or passing an empty list disables collection<br>• `linear_hook`: monitors linear layers<br>&emsp;▪️ collected metrics: `sr`, `kernel_norm`<br>&emsp;▪️ exact layer names must be obtained via [printing the model structure](#打印模型结构); omitting the key disables collection<br>&emsp;▪️ an empty list auto-detects eligible layers (those with a 2-D `weight` or `wg` parameter) | Yes |
+| **recording_l2_features** | bool | Whether to collect L2 feature data; defaults to false (no collection) | No |
+| **sa_order** | str | Dimension layout of the attention inputs (Q, K) used when computing the `attention_hook` metrics. Supports "s,b,h,d" and "b,s,h,d"; defaults to "s,b,h,d", i.e. **s**equence_len -> **b**atch_size -> num_**h**eads -> head_**d**im (see the example below) | No |
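+
+The snippet below only illustrates the two supported layouts (hypothetical sizes, not tool defaults); `sa_order` simply tells the hook how to index Q and K before forming $QK^T$:
+
+```python
+import torch
+
+seq_len, batch, heads, dim = 1024, 2, 16, 128       # hypothetical sizes
+q_sbhd = torch.randn(seq_len, batch, heads, dim)    # "s,b,h,d" layout
+q_bshd = torch.randn(batch, seq_len, heads, dim)    # "b,s,h,d" layout
+```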
+
+
+#### L2 interpretable feature metrics
+
+| **Metric** | **Hook type** | **Definition / computation** | **What it indicates** |
+|---------------------|-----------------|-------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------|
+| **entropy** | attention_hook | $H(p)=-\sum p_i \log p_i$, where $p_i$ are the attention weights | Uncertainty of the attention distribution; a **low** entropy means attention is concentrated |
+| **softmax_max** | attention_hook | $\max(\text{softmax}(QK^T/\sqrt{d}))$ (mean of the per-row maxima) | Focus of the attention mechanism; a **high** value means a single dominant attention token |
+| **sr (stable_rank)** | linear_hook | $\frac{\|W\|_F}{\|W\|_2}$ (stable rank: Frobenius norm divided by spectral norm) | Effective rank of the weight matrix; a **low** value means the matrix is close to a low-rank, unstable state |
+| **kernel_norm** | linear_hook | $\|W\|_2$ (spectral norm, i.e. the largest singular value) | Amplification an input receives along the top singular direction of the weight matrix |
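+
+The metrics can be reproduced with plain PyTorch. The snippet below is only a minimal sketch of the formulas above, not the tool's internal implementation (helper names are illustrative; it assumes an unmasked single-head score matrix and a 2-D weight, and uses exact norms where the tool approximates $\|W\|_2$ with a few power-iteration steps):
+
+```python
+import torch
+
+def attention_metrics(scores: torch.Tensor):
+    """entropy / softmax_max of an [s, s] attention-score matrix."""
+    p = torch.softmax(scores.float(), dim=-1)             # attention weights p_i
+    entropy = (-p * torch.log(p)).nansum(dim=-1).mean()   # H(p) = -sum p_i log p_i, averaged over rows
+    softmax_max = p.amax(dim=-1).mean()                   # mean of the per-row maxima
+    return entropy, softmax_max
+
+def linear_metrics(weight: torch.Tensor):
+    """sr (stable rank) and kernel_norm (spectral norm) of a 2-D weight."""
+    spectral = torch.linalg.matrix_norm(weight.float(), ord=2)     # ||W||_2
+    fro = torch.linalg.matrix_norm(weight.float(), ord="fro")      # ||W||_F
+    return fro / spectral, spectral
+
+q, k = torch.randn(16, 64), torch.randn(16, 64)
+print(attention_metrics(q @ k.t() / 64 ** 0.5))
+print(linear_metrics(torch.randn(128, 256)))
+```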
+
### 输出格式和统计量
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/features.py b/debug/accuracy_tools/msprobe/pytorch/monitor/features.py
index cfd0c1615d2072147d12e7e279527bef86c9d497..960f3dabe32acb8c8ed3018f7421747131e02984 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/features.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/features.py
@@ -111,3 +111,97 @@ def cal_histc(tensor_cal, bins_total, min_val, max_val):
@torch.no_grad()
def get_nans(t):
return torch.isnan(t).sum()
+
+
+def check_tensor_dim(tensor, n):
+ """检查张量维度是否大于n
+ """
+ if not isinstance(tensor, torch.Tensor):
+ raise TypeError(
+ f"Input must be a PyTorch tensor. Got {type(tensor)} instead. "
+ f"Consider using torch.tensor() for conversion."
+ )
+
+ if tensor.dim() < n:
+ raise ValueError(
+ f"Tensor must have at least {n} dimensions. "
+ f"Got shape: {tuple(tensor.shape)} with {tensor.dim()} dims."
+ )
+
+
+@torch.no_grad()
+def max_eigenvalue(input_tensor: torch.Tensor, num_iterations=3):
+ input_tensor = input_tensor.float()
+ try:
+ check_tensor_dim(input_tensor, 2)
+ except (TypeError, ValueError) as e:
+ logger.warning(f"Calculate max eigenvalue failed: {e}")
+ return torch.tensor(0)
+ in_features = input_tensor.shape[1]
+ u_tensor = torch.randn(in_features).to(input_tensor.device)
+ u_norm = u_tensor.norm()
+ if u_norm.item() == 0:
+ return torch.tensor(0)
+ u_tensor = u_tensor / u_tensor.norm()
+ input_seq = torch.matmul(input_tensor.T, input_tensor)
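+ # power iteration on W^T @ W: its dominant eigenvalue is the squared spectral norm of W, hence the sqrt below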
+ for _ in range(num_iterations):
+ v_tensor = torch.matmul(input_seq, u_tensor)
+ spectral_norm = torch.matmul(v_tensor.T, u_tensor)
+ v_norm = v_tensor.norm()
+ if v_norm > 0:
+ u_tensor = v_tensor / v_norm
+ else:
+ spectral_norm = torch.tensor(0)
+ break
+ return spectral_norm.sqrt()
+
+
+@torch.no_grad()
+def cal_entropy(qk_tensor, mask=None):
+ try:
+ check_tensor_dim(qk_tensor, 2)
+ except (TypeError, ValueError) as e:
+ logger.warning(f"Calculate max eigenvalue failed: {e}")
+ return torch.tensor(0), torch.tensor(0)
+ if mask is None:
+ mask = torch.tril(torch.ones(qk_tensor.shape[1], qk_tensor.shape[1])).to(
+ qk_tensor.device)
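+ # shift scores by their per-row max before masking and softmax for numerical stability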
+ qk_tensor = qk_tensor - torch.amax(qk_tensor, dim=1, keepdim=True)
+ qk_tensor = qk_tensor.masked_fill(mask == 0, float('-inf'))
+ softmax_qkt = torch.nn.functional.softmax(qk_tensor.float(), dim=1)
+ # mean of the per-row maxima of softmax(QK^T)
+ softmax_max = torch.mean(torch.amax(softmax_qkt, dim=1))
+ entropy = torch.mean(-torch.nansum(softmax_qkt *
+ torch.log(softmax_qkt), dim=1))
+ return entropy, softmax_max
+
+
+@torch.no_grad()
+def cal_qkt(q_h, k_h, order="s,b,h,d"):
+ # q_h/k_h are 4-D attention inputs; `order` describes their dimension layout, default [s, b, h, d]
+ order = order.replace(' ', '') if isinstance(order, str) else order
+ try:
+ check_tensor_dim(q_h, 4)
+ check_tensor_dim(k_h, 4)
+ except (TypeError, ValueError) as e:
+ logger.warning(f"Calculate qk tensor failed: {e}")
+ return torch.tensor(0)
+
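+ # only the first batch sample and the first attention head (a single [s, d] slice) enter the QK^T computation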
+ if order == "s,b,h,d":
+ qkt = torch.matmul(
+ q_h[:, 0, 0, :], k_h[:, 0, 0, :].t()) / q_h.shape[-1] ** 0.5
+ elif order == "b,s,h,d":
+ qkt = torch.matmul(
+ q_h[0, :, 0, :], k_h[0, :, 0, :].t()) / q_h.shape[-1] ** 0.5
+ else:
+ logger.warning("Calculate qk tensor failed: Order unsupported.")
+ qkt = torch.tensor(0)
+ return qkt
+
+
+@torch.no_grad()
+def cal_stable_rank(weight: torch.Tensor):
+ eig = max_eigenvalue(weight)
+ if eig == torch.tensor(0):
+ return torch.tensor(0), torch.tensor(0)
+ f_norm = torch.norm(weight, p="fro")
+ return f_norm / eig, eig
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
index f39cd7a83384d4f9b3342d77a0f6183b051c7089..0a2cd447c9fc32fb8e6ebcb2aa1baab05577c637 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
@@ -40,9 +40,9 @@ from msprobe.pytorch.monitor.utils import get_param_struct
from msprobe.pytorch.monitor.data_writers import SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD, WriterInput
from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
get_process_group
-from msprobe.pytorch.monitor.features import get_sign_matches
+from msprobe.pytorch.monitor.features import get_sign_matches, cal_qkt
from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
- TensorMetrics, squash_param_name
+ TensorMetrics, squash_param_name, get_entropy_metric, get_sr_metric
from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
@@ -57,6 +57,7 @@ FORMAT_MAPPING = {
MonitorConst.CSV: CSVWriterWithAD,
MonitorConst.API: BaseWriterWithAD
}
+start_step = 0
def param_is_not_tensor_parallel_duplicate(param, tp_group):
@@ -83,7 +84,17 @@ class ModuleHookContext:
self.actvgrad.clear()
-start_step = 0
+class FeatureHookContext:
+ def __init__(self, module_name):
+ self.step = 0
+ self.micro_step = 0
+ self.attention_feature = {}
+ self.linear_feature = {}
+ self.module_name = module_name
+
+ def reset(self):
+ self.attention_feature.clear()
+ self.linear_feature.clear()
class OptimizerContext:
@@ -206,6 +217,7 @@ class TrainerMon:
# TYPE3: 会随着训练中途config配置更新或监控状态改变而重置的变量
self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
+ self.feature_hook_context_by_module = defaultdict(FeatureHookContext)
self.optimizer_context = defaultdict(OptimizerContext)
self.cc_context = defaultdict(CommunicationContext)
self.grad_context = GradContext()
@@ -274,6 +286,18 @@ class TrainerMon:
cc_tensor.reset()
return metrics
+ @staticmethod
+ def get_linear_hook_target(module):
+ if isinstance(module, torch.nn.Embedding):
+ return ''
+ if hasattr(module, "num_embeddings") or hasattr(module, "vocab_start_index"):
+ return ''
+ for weight_name in ["weight", "wg"]:
+ if hasattr(module, weight_name) and isinstance(getattr(module, weight_name), torch.Tensor):
+ if getattr(module, weight_name).dim() == 2:
+ return weight_name
+ return ''
+
def set_config(self):
logger.info(f"current config: {self.config}")
self.start_step = self.config.get("start_step", 0)
@@ -298,6 +322,8 @@ class TrainerMon:
self.cc_distribution = self.config.get("cc_distribution", {})
self.stack_info = self.config.get('stack_info', False)
self.monitor_mbs_grad = self.config.get('monitor_mbs_grad', False)
+ self.recording_l2_features = self.config.get("recording_l2_features", False)
+ self.sa_order = self.config.get("sa_order", "s,b,h,d")
if not self.cc_distribution.get('enable', False):
self.cc_log_only = False
@@ -356,6 +382,8 @@ class TrainerMon:
logger.info_on_rank_0("> momentum and variance of adam is not monitored. ")
if not self.wg_distribution:
logger.info_on_rank_0("> weight grad of specified module is not monitored. ")
+ if not self.recording_l2_features:
+ logger.info_on_rank_0("> l2 features of specified module is not monitored. ")
if not self.mg_direction:
logger.info_on_rank_0('> grad and momentum direction will not be compared.')
if not self.cc_distribution.get('enable', False):
@@ -537,6 +565,24 @@ class TrainerMon:
if self.grad_context.actv:
self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD)
+ def write_metrics_if_not_empty(self, features, metrics, step, hook_name):
+ if not features:
+ return
+ use_micro_step = hook_name not in ["linear_hook"]
+ self.summary_writer.write_metrics(metrics, features, step, hook_name, use_micro_step=use_micro_step)
+ features.clear()
+
+ def write_features_tb(self, step):
+ if not self.recording_l2_features:
+ return
+ for context in self.feature_hook_context_by_module.values():
+ num_features = len(context.attention_feature) + len(context.linear_feature)
+ if num_features == 0:
+ continue
+ self.write_metrics_if_not_empty(context.attention_feature, ["entropy", "softmax_max"],
+ step, "attention_hook")
+ self.write_metrics_if_not_empty(context.linear_feature, ["sr", "kernel_norm"], step, "linear_hook")
+
def write_param_tb(self, opt_context):
if not self.param_distribution:
return
@@ -691,6 +737,7 @@ class TrainerMon:
if self.anomaly_data_factory:
self.anomaly_data_factory.set_call_id(self.param_name_call_id)
self.write_xy_tb(context.step)
+ self.write_features_tb(context.step)
self.write_grad_tb(context.step)
self.write_mv_tb(context)
self.write_param_tb(context)
@@ -760,7 +807,8 @@ class TrainerMon:
vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
'targets'].keys()
- hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
+ l2_target_names = self.config.get('l2_targets', {})
+ hooked_count += self._hook_module(targets, l2_target_names, model_chunk, vpp_stage)
logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
@@ -801,6 +849,9 @@ class TrainerMon:
for handle in self.handles['xy']:
handle.remove()
self.handles['xy'].clear()
+ for handle in self.handles['L2_features']:
+ handle.remove()
+ self.handles['L2_features'].clear()
# 清空对应context缓存
for _, fwd_context in self.module_fwd_hook_context_by_module.items():
fwd_context.reset()
@@ -941,7 +992,20 @@ class TrainerMon:
return pattern
return ""
- def _hook_module(self, target_names, module: torch.nn.Module, vpp_stage=''):
+ def _is_recording_module(self, module_name, l2_targets, vpp_stage, hook_name):
+
+ if len(l2_targets) > 0:
+ for pattern in [
+ vpp_stage + squash_param_name(module_name, self.squash_name),
+ vpp_stage + module_name,
+ ]:
+ if pattern in l2_targets:
+ return pattern
+ elif hook_name in ["linear_hook"]:
+ return vpp_stage + squash_param_name(module_name, self.squash_name)
+ return ""
+
+ def _hook_module(self, target_names, l2_target_names, module: torch.nn.Module, vpp_stage=''):
if '_modules' not in module.__dict__:
# nothing to hook
return 0
@@ -1020,6 +1084,61 @@ class TrainerMon:
context.micro_step = 0
return
+ def extract_attention_feature_hook(module, module_input, module_output, name):
+ if is_recomputation() or not module.training:
+ return
+
+ if module not in self.feature_hook_context_by_module:
+ self.feature_hook_context_by_module[module] = FeatureHookContext(name)
+ context: FeatureHookContext = self.feature_hook_context_by_module[module]
+ tbtag_tensor_map = {}
+ if len(module_input) < 2:
+ logger.warning(
+ f"Length of module_input in attention hook ({name}) is {len(module_input)}, "
+ "expected >= 2. Skipping feature extraction for this module."
+ )
+ return
+ q_h = module_input[0]
+ k_h = module_input[1]
+ qkt = cal_qkt(q_h, k_h, order=self.sa_order)
+ tbtag_tensor_map.update(
+ self.build_tbtag_tensor_map(f'{context.module_name}.attention',
+ f'{MonitorConst.NAME_SEP}{context.micro_step}', 'qkt', qkt)
+ )
+ get_entropy_metric(tbtag_tensor_map, context.attention_feature)
+
+ context.micro_step += 1
+ if context.micro_step == self.micro_batch_number:
+ context.micro_step = 0
+ context.step += 1
+ return
+
+ def extract_linear_sr_hook(module, module_input, module_output, name):
+ if is_recomputation() or not module.training:
+ return
+ weight_name = self.get_linear_hook_target(module)
+ if weight_name == '':
+ return
+
+ if module not in self.feature_hook_context_by_module:
+ self.feature_hook_context_by_module[module] = FeatureHookContext(name)
+ context: FeatureHookContext = self.feature_hook_context_by_module[module]
+
+ if context.micro_step == (self.micro_batch_number - 1):
+ tbtag_tensor_map = {}
+ value = getattr(module, weight_name).data
+ tbtag_tensor_map.update(
+ self.build_tbtag_tensor_map(f'{context.module_name}.linear',
+ '', 'sr', value)
+ )
+ get_sr_metric(tbtag_tensor_map, context.linear_feature)
+
+ context.micro_step += 1
+ if context.micro_step == self.micro_batch_number:
+ context.micro_step = 0
+ context.step += 1
+ return
+
def stack_hook(module, args, kwargs, module_output, name):
if module not in self.module_fwd_hook_context_by_module:
self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
@@ -1051,6 +1170,26 @@ class TrainerMon:
self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
logger.info_on_rank_0(f"> {name} is monitored successfully")
hooked_count += 1
+ if not self.print_struct and self.recording_l2_features:
+ for module_name, submodule in module.named_modules():
+ func_map = {
+ "attention_hook": extract_attention_feature_hook,
+ "linear_hook": extract_linear_sr_hook,
+ }
+ for hook_name in func_map.keys():
+ if hook_name not in l2_target_names:
+ continue
+ temp_names = l2_target_names[hook_name]
+ name = self._is_recording_module(module_name, temp_names, vpp_stage, hook_name)
+ if name:
+ handle = submodule.register_forward_hook(partial(func_map[hook_name], name=name))
+ print_feature_name = hook_name.split('_')[0]
+ logger.info_on_rank_0(
+ f'> {print_feature_name} features of {name} are monitored successfully')
+ self.handles["L2_features"].append(handle)
+ hooked_count += 1
+ continue
+
return hooked_count
def _patch_grad_sync(self):
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py
index c5730d7846ca915d524eb1c3ce76e805f321a652..24e9b43d6378e37f09d3c654511c840e9168f6f6 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py
@@ -17,6 +17,7 @@ import re
import torch
from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
+from msprobe.pytorch.monitor.features import cal_entropy, cal_stable_rank
from msprobe.pytorch.monitor.utils import get_nan_tensor
@@ -185,3 +186,27 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None):
fun_metric = config_metric_registry.get(metric_name)
out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
return out_dict
+
+
+def get_sr_metric(tag2tensor, out_dict=None):
+ if out_dict is None:
+ out_dict = {}
+ for tag, tensor in tag2tensor.items():
+ if "sr" not in tag:
+ continue
+ if tag not in out_dict:
+ out_dict[tag] = {}
+ sr, eig = cal_stable_rank(tensor)
+ out_dict[tag]['sr'] = sr
+ out_dict[tag]['kernel_norm'] = eig
+
+
+def get_entropy_metric(tag2tensor, out_dict=None):
+ if out_dict is None:
+ out_dict = {}
+ for tag, tensor in tag2tensor.items():
+ if tag not in out_dict:
+ out_dict[tag] = {}
+ entropy, softmax_max = cal_entropy(tensor)
+ out_dict[tag]['entropy'] = entropy
+ out_dict[tag]['softmax_max'] = softmax_max
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py
index ff00cf7490d8110f2198df57ee5d91b6b75f5092..a2e7ecb512d195305e48badd03c7e1b4ae30b237 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py
@@ -1,8 +1,10 @@
import unittest
+from unittest.mock import patch
+
import torch
from msprobe.pytorch.monitor.features import square_sum, get_min, get_mean, get_norm, get_max, get_zeros, \
get_sign_matches, eff_rank, mNTK, lambda_max_subsample, cal_histc, get_nans
-
+from msprobe.pytorch.monitor.features import max_eigenvalue, cal_entropy, cal_qkt, cal_stable_rank
class TestMathFunctions(unittest.TestCase):
def test_square_sum(self):
@@ -87,6 +89,74 @@ class TestMathFunctions(unittest.TestCase):
result = get_nans(tensor)
self.assertEqual(result, 1)
+ def test_max_eigenvalue(self):
+ """测试最大特征值计算"""
+ # 创建已知特征值的矩阵
+ A = torch.diag(torch.tensor([3.0, 2.0, 1.0]))
+
+ # test with a different number of iterations
+ eigval = max_eigenvalue(A, num_iterations=5)
+ self.assertAlmostEqual(eigval.item(), 3.0, delta=0.1)
+
+ # all-zero matrix
+ zero_matrix = torch.zeros(3, 3)
+ eigval = max_eigenvalue(zero_matrix)
+ self.assertAlmostEqual(eigval.item(), 0.0)
+
+ def test_cal_entropy(self):
+ """测试注意力熵计算"""
+ # 创建简单的注意力分数
+ qk = torch.tensor([[1.0, 2.0, 3.0],
+ [4.0, 5.0, 6.0],
+ [7.0, 8.0, 9.0]])
+
+ # no explicit mask (the default causal mask is generated)
+ entropy, softmax_max = cal_entropy(qk)
+ self.assertAlmostEqual(entropy, 0.4715, delta=0.1)
+ self.assertAlmostEqual(softmax_max, 0.7988, delta=0.1)
+
+ # an explicit lower-triangular mask matches the default one
+ mask = torch.tensor([[1, 0, 0],
+ [1, 1, 0],
+ [1, 1, 1]], dtype=torch.float)
+ entropy, softmax_max = cal_entropy(qk, mask)
+ self.assertAlmostEqual(entropy, 0.4715, delta=0.1)
+ self.assertAlmostEqual(softmax_max, 0.7988, delta=0.1)
+
+ @patch("msprobe.pytorch.monitor.features.logger")
+ def test_cal_qkt(self, mock_logger):
+ """测试QK^T计算"""
+ # 测试s,b,h,d顺序
+ q = torch.randn(10, 2, 4, 8) # [s, b, h, d]
+ k = torch.randn(10, 2, 4, 8) # [s, b, h, d]
+ q_batch = torch.randn(2, 10, 4, 8) # [b, s, h, d]
+ qkt = cal_qkt(q, k, order="s,b,h,d")
+ self.assertEqual(qkt.shape, (10, 10)) # [s, s]
+
+ # "b,s,h,d" order
+ qkt = cal_qkt(q_batch, q_batch, order="b,s,h,d")
+ self.assertEqual(qkt.shape, (10, 10)) # [s, s]
+
+ # unsupported order
+ cal_qkt(q, k, order="invalid_order")
+ mock_logger.warning.assert_called_with(
+ "Calculate qk tensor failed: Order unsupported.")
+
+ def test_cal_stable_rank(self):
+ """测试谱半径计算"""
+ # 创建已知谱半径的矩阵
+ A = torch.diag(torch.tensor([3.0, 2.0, 1.0]))
+ sr, eig = cal_stable_rank(A)
+
+ # check against the Frobenius norm
+ fro_norm = torch.norm(A, p='fro')
+ self.assertAlmostEqual(sr, fro_norm / 3.0, delta=.5)  # largest singular value is 3
+
+ # identity (orthogonal) matrix
+ ortho = torch.eye(5)
+ sr, eig = cal_stable_rank(ortho)
+ self.assertAlmostEqual(sr, torch.tensor(2.24), delta=.5)  # Frobenius norm is sqrt(5) ≈ 2.24
+ self.assertAlmostEqual(eig, torch.tensor(1.0), delta=.1)  # spectral norm should be 1
if __name__ == '__main__':
unittest.main()
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py
index 83e8217c894d38b1d8506cb0e1cd241ffcbcb759..2f10d4f12906bfd91cfd304d157f5946dd68524b 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py
@@ -8,7 +8,7 @@ from msprobe.core.common.const import MonitorConst
from msprobe.core.monitor.utils import filter_special_chars, MsgConst, validate_ops, validate_ranks, \
validate_targets, validate_print_struct, validate_ur_distribution, validate_xy_distribution, \
validate_mg_distribution, validate_wg_distribution, validate_cc_distribution, validate_alert, validate_config, \
- get_output_base_dir
+ get_output_base_dir, validate_l2_targets, validate_recording_l2_features, validate_sa_order
from msprobe.pytorch.monitor.utils import get_param_struct
from msprobe.pytorch.common.utils import is_recomputation
@@ -112,6 +112,65 @@ class TestValidationFunctions(unittest.TestCase):
self.assertEqual(config["targets"], {"": {}})
self.assertEqual(config["all_xy"], True)
+ # ===== validate_l2_targets tests =====
+ def test_validate_l2_targets_valid_input(self):
+ """测试合法输入"""
+ valid_targets = {
+ "attention_hook": ["0:0.self_attention.core_attention.flash_attention"],
+ "linear_hook": []
+ }
+ validate_l2_targets(valid_targets)
+
+ def test_validate_l2_targets_invalid_root_type(self):
+ """测试非 dict 输入"""
+ with self.assertRaises(TypeError) as cm:
+ validate_l2_targets("not_a_dict")
+ self.assertEqual(str(cm.exception),
+ 'l2_targets in config.json should be a dict')
+
+ def test_validate_l2_targets_invalid_hook_name(self):
+ """测试非法 hook_name"""
+ with self.assertRaises(TypeError) as cm:
+ validate_l2_targets({"invalid_hook": ["module1"]})
+ self.assertIn(f'key of l2_targets must be in {MonitorConst.L2_HOOKS}',
+ str(cm.exception))
+
+ def test_validate_l2_targets_invalid_value_type(self):
+ """测试非法 value 类型"""
+ with self.assertRaises(TypeError) as cm:
+ validate_l2_targets({"linear_hook": "not_a_list"})
+ self.assertEqual(str(cm.exception),
+ 'values of l2_targets should be a list in config.json')
+
+ def test_validate_l2_targets_invalid_item_type(self):
+ """测试非法 list item 类型"""
+ with self.assertRaises(TypeError) as cm:
+ validate_l2_targets({"linear_hook": [123]})
+ self.assertEqual(str(cm.exception),
+ 'item of "linear_hook" in l2_targets should be module_name[str] in config.json')
+
+ # ===== validate_recording_l2_features tests =====
+ def test_validate_recording_l2_features_valid(self):
+ """测试合法布尔值输入"""
+ validate_recording_l2_features(True)
+ validate_recording_l2_features(False)
+
+ def test_validate_recording_l2_features_invalid_type(self):
+ """测试非法类型输入"""
+ with self.assertRaises(TypeError) as cm:
+ validate_recording_l2_features("xx")
+ self.assertEqual(str(cm.exception),
+ "recording_l2_features should be a bool")
+
+ # ===== validate_sa_order tests =====
+ def test_valid_orders(self):
+ validate_sa_order("b,s,h,d")
+ validate_sa_order("s, b,h, d")
+
+ def test_invalid_orders(self):
+ with self.assertRaises(TypeError) as cm:
+ validate_recording_l2_features("xx")
+ self.assertEqual(str(cm.exception),
+ f'sa_order must be in {MonitorConst.SA_ORDERS}, got xx')
class TestIsRecomputation(unittest.TestCase):
@patch('inspect.stack')